import shlex
import subprocess
import tempfile
+ import warnings
from dataclasses import dataclass
from datetime import datetime
from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


+
+ def version() -> Tuple[int, int]:
+     """
+     Uses ``sinfo --version`` to get the Slurm version. If the command fails, it
+     assumes the version is ``slurm 24.05.8``.
+
+     Returns:
+         Tuple[int, int]: Slurm version as a tuple of ints (major, minor).
+     """
+
+     cmd = ["sinfo", "--version"]
+     try:
+         out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
+     except (CalledProcessError, FileNotFoundError):
+         out = "slurm 24.05.8"
+         warnings.warn(
+             f"Error running: `{shlex.join(cmd)}` to get SLURM version. Are you running outside the "
+             "cluster's login or head node? This typically happens when running in `--dryrun`"
+             " mode. Assuming version is `slurm 24.05.8`.",
+             RuntimeWarning,
+             stacklevel=2,
+         )
+
+     # sinfo --version returns in the form "slurm 24.1.0"
+     _, version_literal = out.split(" ", maxsplit=2)
+     major, minor = [int(v) for v in version_literal.split(".")][:2]
+
+     return (major, minor)
+
+
+ def _should_use_gpus_per_node_from_version() -> bool:
+     """
+     Determine whether to use gpus-per-node based on the automatically detected Slurm version.
+
+     Change Reference: https://fburl.com/sqwqzxn6
+     > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
+
+     Returns:
+         ``True`` for Slurm ``version >= 24.11.0``, ``False`` otherwise.
+     """
+
+     slurm_24_11_0 = (24, 11)
+     slurm_version = version()
+
+     return slurm_version[0] > slurm_24_11_0[0] or (  # Major version is greater
+         slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
+     )  # Major version is equal and minor version is greater or equal
+
+
SBATCH_JOB_OPTIONS = {
    "comment",
    "mail-user",
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
    "partition",
    "time",
    "constraint",
+     "qos",
}

log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
        "mail-user": Optional[str],
        "mail-type": Optional[str],
        "job_dir": Optional[str],
+         "qos": Optional[str],
    },
    total=False,
)
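
Since the TypedDict is declared with ``total=False``, any subset of keys is a valid run config, so a config carrying the new field might look like the following (the values are placeholders, not defaults from the code):

# Illustrative sketch, not part of the patch; values are made up.
cfg: SlurmOpts = {
    "partition": "compute",   # pre-existing option
    "time": "1:00:00",        # pre-existing option
    "qos": "high",            # new optional Quality of Service field
}
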
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:

    @classmethod
    def from_role(
-         cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
+         cls,
+         name: str,
+         role: Role,
+         cfg: SlurmOpts,
+         nomem: bool,
    ) -> "SlurmReplicaRequest":
        """
        ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
        if not nomem and resource.memMB > 0:
            sbatch_opts.setdefault("mem", str(resource.memMB))
        if resource.gpu > 0:
-             sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+             # Choose the GPU allocation flag based on the detected Slurm version
+             if _should_use_gpus_per_node_from_version():
+                 sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+             else:
+                 sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

        srun_opts = {
            "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
            iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
            """,
        )
+         opts.add(
+             "qos",
+             type_=str,
+             help="Quality of Service (QoS) to assign to the job.",
+         )
        return opts

    def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
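
As a usage note rather than part of the change: once ``qos`` is a registered run option, it should be accepted like any other Slurm run config, for example from the torchx CLI. The partition and QoS names below are placeholders, and the builtin ``utils.echo`` component is used only as an example payload:

torchx run -s slurm -cfg partition=compute,qos=high utils.echo --msg hello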