Skip to content

Commit ae55901

Browse files
Version based GPU configuration and QoS addition
Differential Revision: D78778304 Pull Request resolved: #1092
1 parent dc70d90 commit ae55901

File tree

2 files changed

+359
-16
lines changed

2 files changed

+359
-16
lines changed

torchx/schedulers/slurm_scheduler.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import shlex
1919
import subprocess
2020
import tempfile
21+
import warnings
2122
from dataclasses import dataclass
2223
from datetime import datetime
2324
from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
7273
return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
7374

7475

76+
def version() -> Tuple[int, int]:
77+
"""
78+
Uses ``sinfo --version`` to get the slurm version. If the command fails, it
79+
assumes the version is ``slurm 24.05.8``.
80+
81+
Returns:
82+
-------
83+
Tuple[int, int] slurm version as a tuple of ints (major, minor).
84+
"""
85+
86+
cmd = ["sinfo", "--version"]
87+
try:
88+
out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
89+
except (CalledProcessError, FileNotFoundError):
90+
out = "slurm 24.05.8"
91+
warnings.warn(
92+
"Error running: `{sinfo_cmd}` to get SLURM version. Are you running outside the "
93+
"cluster's login or head node? This typically happens when running in `--dryrun`"
94+
" mode. Assuming version is `slurm 24.05.8`.",
95+
RuntimeWarning,
96+
stacklevel=2,
97+
)
98+
99+
# sinfo --version returns in the form "slurm 24.1.0"
100+
_, version_literal = out.split(" ", maxsplit=2)
101+
major, minor = [int(v) for v in version_literal.split(".")][:2]
102+
103+
return (major, minor)
104+
105+
def _should_use_gpus_per_node_from_version() -> bool:
    """
    Determine whether to use gpus-per-node based on automatically detected slurm version.

    Change Reference: https://fburl.com/sqwqzxn6
    > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.

    Returns:
        ``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
    """

    # Python compares tuples lexicographically, so this single expression is
    # equivalent to "major is greater, or major is equal and minor >= 11".
    minimum_version = (24, 11)
    return version() >= minimum_version
123+
124+
75125
SBATCH_JOB_OPTIONS = {
76126
"comment",
77127
"mail-user",
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
81131
"partition",
82132
"time",
83133
"constraint",
134+
"qos",
84135
}
85136

86137
log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
106157
"mail-user": Optional[str],
107158
"mail-type": Optional[str],
108159
"job_dir": Optional[str],
160+
"qos": Optional[str],
109161
},
110162
total=False,
111163
)
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:
126178

127179
@classmethod
128180
def from_role(
129-
cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
181+
cls,
182+
name: str,
183+
role: Role,
184+
cfg: SlurmOpts,
185+
nomem: bool,
130186
) -> "SlurmReplicaRequest":
131187
"""
132188
``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
149205
if not nomem and resource.memMB > 0:
150206
sbatch_opts.setdefault("mem", str(resource.memMB))
151207
if resource.gpu > 0:
152-
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
208+
# Use smart GPU allocation based on automatically detected Slurm version
209+
if _should_use_gpus_per_node_from_version():
210+
sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
211+
else:
212+
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
153213

154214
srun_opts = {
155215
"output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
378438
iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379439
""",
380440
)
441+
opts.add(
442+
"qos",
443+
type_=str,
444+
help="Quality of Service (QoS) to assign to the job.",
445+
)
381446
return opts
382447

383448
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:

0 commit comments

Comments
 (0)