
Commit d06b519

Version based GPU configuration and QoS addition (#1092)
Summary: Slurm 24.11.0rc1 and beyond do not support GRES per task, so we need to pass `gpus-per-node` to sbatch to ensure failure-free allocation. See https://github.com/SchedMD/slurm/blob/master/CHANGELOG/slurm-24.11.md

Changes:
1. Introduced Slurm-version-based GPU request configuration.
2. Introduced an optional `qos` parameter that can be used to control job priority.

Differential Revision: D78778304
Parent: 4adf7f6
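The version gate introduced by this commit is strict: the scheduler switches to `gpus-per-node` only for versions strictly newer than 24.11. A minimal standalone sketch of the comparison implemented by `_should_use_gpus_per_node_from_version` in the diff below (illustration only, not the committed code):

```python
def use_gpus_per_node(version: str) -> bool:
    """Mirror of the version gate: True for Slurm releases strictly newer than 24.11."""
    major, minor = (int(p) for p in version.split(".")[:2])
    return (major, minor) > (24, 11)

assert use_gpus_per_node("24.05") is False   # pre-24.11: keep gpus-per-task
assert use_gpus_per_node("24.11") is False   # boundary release: still gpus-per-task
assert use_gpus_per_node("25.11.2") is True  # newer: switch to gpus-per-node
```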

File tree: 2 files changed (+409, −11 lines)


torchx/schedulers/slurm_scheduler.py (123 additions, 3 deletions)
@@ -49,6 +49,8 @@
 
 SLURM_JOB_DIRS = ".torchxslurmjobdirs"
 
+DEFAULT_SLURM_VERSION: str = "1.0"
+
 SLURM_STATES: Mapping[str, AppState] = {
     "BOOT_FAIL": AppState.FAILED,
     "CANCELLED": AppState.CANCELLED,
@@ -72,6 +74,45 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
 
 
+def _parse_slurm_version(version_str: str) -> Tuple[int, int]:
+    """
+    Parse Slurm version string (e.g., '24.05', '25.11.2') into (major, minor) tuple.
+    Raises ValueError if parsing fails.
+    """
+    parts = version_str.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        )
+
+    try:
+        major = int(parts[0])
+        minor = int(parts[1])
+    except (ValueError, IndexError) as err:
+        raise ValueError(
+            f"Invalid Slurm version string: {version_str}. Expected format: '24.05' or '25.11.2'"
+        ) from err
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version(version_str: Optional[str]) -> bool:
+    """
+    Determine whether to use gpus-per-node based on version string.
+    Returns True if version > 24.11, False otherwise or if version cannot be parsed.
+    """
+    if not version_str:
+        return False
+
+    try:
+        major, minor = _parse_slurm_version(version_str)
+    except ValueError:
+        return False
+
+    # Use gpus-per-node if version > 24.11
+    return major > 24 or (major == 24 and minor > 11)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -81,6 +122,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     "partition",
     "time",
     "constraint",
+    "qos",
 }
 
 log: logging.Logger = logging.getLogger(__name__)
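A hypothetical interactive session showing the parser's two paths, assuming the private helper is importable from the module above:

```python
>>> from torchx.schedulers.slurm_scheduler import _parse_slurm_version
>>> _parse_slurm_version("25.11.2")  # extra patch components are ignored
(25, 11)
>>> _parse_slurm_version("24")       # fewer than two components -> ValueError
Traceback (most recent call last):
    ...
ValueError: Invalid Slurm version string: 24. Expected format: '24.05' or '25.11.2'
```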
@@ -106,6 +148,8 @@ def _apply_app_id_env(s: str) -> str:
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
+        "slurm_version": Optional[str],
     },
     total=False,
 )
@@ -126,7 +170,11 @@ class SlurmReplicaRequest:
 
     @classmethod
     def from_role(
-        cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
+        cls,
+        name: str,
+        role: Role,
+        cfg: SlurmOpts,
+        nomem: bool,
     ) -> "SlurmReplicaRequest":
         """
         ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +197,12 @@ def from_role(
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-            sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+            # Use version-aware GPU allocation based on the Slurm version from config
+            slurm_version = cfg.get("slurm_version")
+            if _should_use_gpus_per_node_from_version(slurm_version):
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
 
         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +431,18 @@ def _run_opts(self) -> runopts:
             iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
             """,
         )
+        opts.add(
+            "qos",
+            type_=str,
+            help="Quality of Service (QoS) to assign to the job.",
+        )
+        opts.add(
+            "slurm_version",
+            type_=str,
+            help="""Slurm version (e.g., '24.05', '25.11'). If version > 24.11,
+            uses gpus-per-node instead of gpus-per-task for GPU allocation.
+            """,
+        )
         return opts
 
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
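With the two new runopts, a scheduler cfg might look like the following sketch (key names come from the `SlurmOpts` additions above; the values are hypothetical):

```python
cfg: SlurmOpts = {
    "partition": "gpu",        # existing option
    "qos": "high",             # new: QoS controlling job priority
    "slurm_version": "25.05",  # new: > 24.11, so GPUs are requested per node
}
```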
@@ -401,6 +466,55 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
 
         return job_id
 
+    def _get_slurm_version(self) -> str:
+        """
+        _get_slurm_version returns the Slurm version string (e.g., "24.05").
+        Falls back to DEFAULT_SLURM_VERSION ("1.0") if the version cannot be determined.
+        """
+        try:
+            p = subprocess.run(
+                ["sinfo", "--version"],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+        except FileNotFoundError:
+            log.error(
+                "Slurm is not available (sinfo command not found). "
+                "Returning default 1.0 instead."
+            )
+
+            return DEFAULT_SLURM_VERSION
+
+        if p.returncode != 0:
+            log.error(
+                f"Failed to get Slurm version: {p.stderr.decode('utf-8').strip()}. "
+                "Returning default 1.0 instead."
+            )
+
+            return DEFAULT_SLURM_VERSION
+
+        output = p.stdout.decode("utf-8").strip().lower()
+        if not output.startswith("slurm "):
+            log.error(
+                f"Unexpected sinfo --version output format: {output}. "
+                "Returning default 1.0 instead."
+            )
+
+            return DEFAULT_SLURM_VERSION
+
+        # Remove "slurm " prefix and extract version (e.g., "24.05.4" -> "24.05")
+        version_full = output.replace("slurm", "").strip()
+        version_parts = version_full.split(".")
+        if len(version_parts) < 2:
+            log.error(
+                f"Invalid Slurm version format: `{version_full}`; "
+                "Returning default 1.0 instead."
+            )
+
+            return DEFAULT_SLURM_VERSION
+
+        return f"{version_parts[0]}.{version_parts[1]}"
+
     def _partition_memmb(self, partition: Optional[str]) -> Optional[int]:
         """
         _partition_memmb returns the memory allocation for the given partition
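To trace the happy path, here is what the string handling in `_get_slurm_version` above does to typical `sinfo --version` output (hypothetical value):

```python
output = "slurm 24.05.4".strip().lower()
version_full = output.replace("slurm", "").strip()  # -> "24.05.4"
version_parts = version_full.split(".")             # -> ["24", "05", "4"]
print(f"{version_parts[0]}.{version_parts[1]}")     # -> "24.05"
```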
@@ -441,6 +555,12 @@ def _submit_dryrun(
         partition = cfg.get("partition")
         assert partition is None or isinstance(partition, str), "partition must be str"
 
+        # Create a new config with the resolved slurm version
+        resolved_cfg = cfg.copy()
+        resolved_cfg["slurm_version"] = cfg.get(
+            "slurm_version", self._get_slurm_version()
+        )
+
         # check if the partition has at least 1GB memory, if we're not sure,
         # default to using memory allocations
         memmb = self._partition_memmb(partition)
@@ -460,7 +580,7 @@ def _submit_dryrun(
             replicas[name] = SlurmReplicaRequest.from_role(
                 name,
                 replica_role,
-                cfg,
+                resolved_cfg,
                 nomem=nomem,
             )
         cmd = ["sbatch", "--parsable"]
