49
49
50
50
SLURM_JOB_DIRS = ".torchxslurmjobdirs"
51
51
52
+ DEFAULT_SLURM_VERSION : str = "1.0"
53
+
52
54
SLURM_STATES : Mapping [str , AppState ] = {
53
55
"BOOT_FAIL" : AppState .FAILED ,
54
56
"CANCELLED" : AppState .CANCELLED ,
@@ -72,6 +74,45 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
72
74
return SLURM_STATES .get (slurm_state , AppState .UNKNOWN )
73
75
74
76
77
+ def _parse_slurm_version (version_str : str ) -> Tuple [int , int ]:
78
+ """
79
+ Parse Slurm version string (e.g., '24.05', '25.11.2') into (major, minor) tuple.
80
+ Raises ValueError if parsing fails.
81
+ """
82
+ parts = version_str .split ("." )
83
+ if len (parts ) < 2 :
84
+ raise ValueError (
85
+ f"Invalid Slurm version string: { version_str } . Expected format: '24.05' or '25.11.2'"
86
+ )
87
+
88
+ try :
89
+ major = int (parts [0 ])
90
+ minor = int (parts [1 ])
91
+ except (ValueError , IndexError ) as err :
92
+ raise ValueError (
93
+ f"Invalid Slurm version string: { version_str } . Expected format: '24.05' or '25.11.2'"
94
+ ) from err
95
+
96
+ return (major , minor )
97
+
98
+
99
def _should_use_gpus_per_node_from_version(version_str: Optional[str]) -> bool:
    """
    Decide whether GPUs should be requested with gpus-per-node.

    Returns True only when ``version_str`` parses to a version strictly newer
    than 24.11. A missing, empty, or unparsable version yields False.
    """
    if version_str:
        try:
            parsed = _parse_slurm_version(version_str)
        except ValueError:
            return False
        # Lexicographic tuple comparison: major first, then minor.
        return parsed > (24, 11)

    return False
114
+
115
+
75
116
SBATCH_JOB_OPTIONS = {
76
117
"comment" ,
77
118
"mail-user" ,
@@ -81,6 +122,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
81
122
"partition" ,
82
123
"time" ,
83
124
"constraint" ,
125
+ "qos" ,
84
126
}
85
127
86
128
log : logging .Logger = logging .getLogger (__name__ )
@@ -106,6 +148,8 @@ def _apply_app_id_env(s: str) -> str:
106
148
"mail-user" : Optional [str ],
107
149
"mail-type" : Optional [str ],
108
150
"job_dir" : Optional [str ],
151
+ "qos" : Optional [str ],
152
+ "slurm_version" : Optional [str ],
109
153
},
110
154
total = False ,
111
155
)
@@ -126,7 +170,11 @@ class SlurmReplicaRequest:
126
170
127
171
@classmethod
128
172
def from_role (
129
- cls , name : str , role : Role , cfg : SlurmOpts , nomem : bool
173
+ cls ,
174
+ name : str ,
175
+ role : Role ,
176
+ cfg : SlurmOpts ,
177
+ nomem : bool ,
130
178
) -> "SlurmReplicaRequest" :
131
179
"""
132
180
``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +197,12 @@ def from_role(
149
197
if not nomem and resource .memMB > 0 :
150
198
sbatch_opts .setdefault ("mem" , str (resource .memMB ))
151
199
if resource .gpu > 0 :
152
- sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
200
+ # Use smart GPU allocation based on Slurm version from config
201
+ slurm_version = cfg .get ("slurm_version" )
202
+ if _should_use_gpus_per_node_from_version (slurm_version ):
203
+ sbatch_opts .setdefault ("gpus-per-node" , str (resource .gpu ))
204
+ else :
205
+ sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
153
206
154
207
srun_opts = {
155
208
"output" : f"slurm-{ macros .app_id } -{ name } .out" ,
@@ -378,6 +431,18 @@ def _run_opts(self) -> runopts:
378
431
iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379
432
""" ,
380
433
)
434
+ opts .add (
435
+ "qos" ,
436
+ type_ = str ,
437
+ help = "Quality of Service (QoS) to assign to the job." ,
438
+ )
439
+ opts .add (
440
+ "slurm_version" ,
441
+ type_ = str ,
442
+ help = """Slurm version (e.g., '24.05', '25.11'). If version > 24.11,
443
+ uses gpus-per-node instead of gpus-per-task for GPU allocation.
444
+ """ ,
445
+ )
381
446
return opts
382
447
383
448
def schedule (self , dryrun_info : AppDryRunInfo [SlurmBatchRequest ]) -> str :
@@ -401,6 +466,55 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
401
466
402
467
return job_id
403
468
469
def _get_slurm_version(self) -> str:
    """
    _get_slurm_version returns the ``major.minor`` Slurm version string
    (e.g., "24.05") as reported by ``sinfo --version``.

    This does not raise when the version cannot be determined: if ``sinfo``
    is not found, exits with a non-zero status, or prints output that is not
    of the form ``slurm <major>.<minor>[...]``, the problem is logged and
    ``DEFAULT_SLURM_VERSION`` is returned instead.
    """
    try:
        p = subprocess.run(
            ["sinfo", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError:
        log.error(
            "Slurm is not available (sinfo command not found). "
            "Returning default %s instead.",
            DEFAULT_SLURM_VERSION,
        )
        return DEFAULT_SLURM_VERSION

    if p.returncode != 0:
        log.error(
            "Failed to get Slurm version: %s. Returning default %s instead.",
            p.stderr.decode("utf-8").strip(),
            DEFAULT_SLURM_VERSION,
        )
        return DEFAULT_SLURM_VERSION

    prefix = "slurm "
    output = p.stdout.decode("utf-8").strip().lower()
    if not output.startswith(prefix):
        log.error(
            "Unexpected sinfo --version output format: %s. "
            "Returning default %s instead.",
            output,
            DEFAULT_SLURM_VERSION,
        )
        return DEFAULT_SLURM_VERSION

    # Drop only the leading "slurm " prefix that was just verified, then keep
    # the first two components (e.g., "24.05.4" -> "24.05").
    version_full = output[len(prefix):].strip()
    version_parts = version_full.split(".")
    if len(version_parts) < 2:
        log.error(
            "Invalid Slurm version format: `%s`; Returning default %s instead.",
            version_full,
            DEFAULT_SLURM_VERSION,
        )
        return DEFAULT_SLURM_VERSION

    return f"{version_parts[0]}.{version_parts[1]}"
517
+
404
518
def _partition_memmb (self , partition : Optional [str ]) -> Optional [int ]:
405
519
"""
406
520
_partition_memmb returns the memory allocation for the given partition
@@ -441,6 +555,12 @@ def _submit_dryrun(
441
555
partition = cfg .get ("partition" )
442
556
assert partition is None or isinstance (partition , str ), "partition must be str"
443
557
558
+ # Create a new config with the resolved slurm version
559
+ resolved_cfg = cfg .copy ()
560
+ resolved_cfg ["slurm_version" ] = cfg .get (
561
+ "slurm_version" , self ._get_slurm_version ()
562
+ )
563
+
444
564
# check if the partition has at least 1GB memory, if we're not sure,
445
565
# default to using memory allocations
446
566
memmb = self ._partition_memmb (partition )
@@ -460,7 +580,7 @@ def _submit_dryrun(
460
580
replicas [name ] = SlurmReplicaRequest .from_role (
461
581
name ,
462
582
replica_role ,
463
- cfg ,
583
+ resolved_cfg ,
464
584
nomem = nomem ,
465
585
)
466
586
cmd = ["sbatch" , "--parsable" ]
0 commit comments