Skip to content

Commit 6e9777c

Browse files
authored
Merge branch 'main' into patch-1
2 parents baabb59 + cc41de6 commit 6e9777c

File tree

10 files changed

+565
-62
lines changed

10 files changed

+565
-62
lines changed

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ torch>=2.7.0
2929
torchmetrics==1.6.3
3030
torchserve>=0.10.0
3131
torchtext==0.18.0
32-
torchvision==0.22.0
32+
torchvision==0.23.0
3333
typing-extensions
3434
ts==0.5.1
3535
ray[default]

torchx/cli/cmd_run.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,7 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
207207
" (e.g. `local_cwd`)"
208208
)
209209

210-
scheduler_opts = runner.scheduler_run_opts(args.scheduler)
211-
cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
210+
cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args))
212211
config.apply(scheduler=args.scheduler, cfg=cfg)
213212

214213
component, component_args = _parse_component_name_and_args(
@@ -263,12 +262,14 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
263262
sys.exit(1)
264263
except specs.InvalidRunConfigException as e:
265264
error_msg = (
266-
f"Scheduler arg is incorrect or missing required option: `{e.cfg_key}`\n"
267-
f"Run `torchx runopts` to check configuration for `{args.scheduler}` scheduler\n"
268-
f"Use `-cfg` to specify run cfg as `key1=value1,key2=value2` pair\n"
269-
"of setup `.torchxconfig` file, see: https://pytorch.org/torchx/main/experimental/runner.config.html"
265+
"Invalid scheduler configuration: %s\n"
266+
"To configure scheduler options, either:\n"
267+
" 1. Use the `-cfg` command-line argument, e.g., `-cfg key1=value1,key2=value2`\n"
268+
" 2. Set up a `.torchxconfig` file. For more details, visit: https://pytorch.org/torchx/main/runner.config.html\n"
269+
"Run `torchx runopts %s` to check all available configuration options for the "
270+
"`%s` scheduler."
270271
)
271-
logger.error(error_msg)
272+
print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
272273
sys.exit(1)
273274

274275
def run(self, args: argparse.Namespace) -> None:

torchx/runner/api.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def __init__(
129129
def _get_scheduler_params_from_env(self) -> Dict[str, str]:
130130
scheduler_params = {}
131131
for key, value in os.environ.items():
132-
lower_case_key = key.lower()
133-
if lower_case_key.startswith("torchx_"):
134-
scheduler_params[lower_case_key.strip("torchx_")] = value
132+
key = key.lower()
133+
if key.startswith("torchx_"):
134+
scheduler_params[key.removeprefix("torchx_")] = value
135135
return scheduler_params
136136

137137
def __enter__(self) -> "Self":
@@ -486,6 +486,27 @@ def scheduler_run_opts(self, scheduler: str) -> runopts:
486486
"""
487487
return self._scheduler(scheduler).run_opts()
488488

489+
def cfg_from_str(self, scheduler: str, *cfg_literal: str) -> Mapping[str, CfgVal]:
490+
"""
491+
Convenience function around the scheduler's ``runopts.cfg_from_str()`` method.
492+
493+
Usage:
494+
495+
.. doctest::
496+
497+
from torchx.runner import get_runner
498+
499+
runner = get_runner()
500+
cfg = runner.cfg_from_str("local_cwd", "log_dir=/tmp/foobar", "prepend_cwd=True")
501+
assert cfg == {"log_dir": "/tmp/foobar", "prepend_cwd": True, "auto_set_cuda_visible_devices": False}
502+
"""
503+
504+
opts = self._scheduler(scheduler).run_opts()
505+
cfg = {}
506+
for cfg_str in cfg_literal:
507+
cfg.update(opts.cfg_from_str(cfg_str))
508+
return cfg
509+
489510
def scheduler_backends(self) -> List[str]:
490511
"""
491512
Returns a list of all supported scheduler backends.

torchx/runner/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,14 @@ def dump(
278278
continue
279279

280280
# serialize list elements with `;` delimiter (consistent with torchx cli)
281-
if opt.opt_type == List[str]:
281+
if opt.is_type_list_of_str:
282282
# deal with empty or None default lists
283283
if opt.default:
284284
# pyre-ignore[6] opt.default type checked already as List[str]
285285
val = ";".join(opt.default)
286286
else:
287287
val = _NONE
288-
elif opt.opt_type == Dict[str, str]:
288+
elif opt.is_type_dict_of_str:
289289
# deal with empty or None default lists
290290
if opt.default:
291291
# pyre-ignore[16] opt.default type checked already as Dict[str, str]
@@ -536,26 +536,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
536536
# this also handles empty or None lists
537537
cfg[name] = None
538538
else:
539-
runopt = runopts.get(name)
539+
opt = runopts.get(name)
540540

541-
if runopt is None:
541+
if opt is None:
542542
log.warning(
543543
f"`{name} = {value}` was declared in the [{section}] section "
544544
f" of the config file but is not a runopt of `{scheduler}` scheduler."
545545
f" Remove the entry from the config file to no longer see this warning"
546546
)
547547
else:
548-
if runopt.opt_type is bool:
548+
if opt.opt_type is bool:
549549
# need to handle bool specially since str -> bool is based on
550550
# str emptiness not value (e.g. bool("False") == True)
551551
cfg[name] = config.getboolean(section, name)
552-
elif runopt.opt_type is List[str]:
552+
elif opt.is_type_list_of_str:
553553
cfg[name] = value.split(";")
554-
elif runopt.opt_type is Dict[str, str]:
554+
elif opt.is_type_dict_of_str:
555555
cfg[name] = {
556556
s.split(":", 1)[0]: s.split(":", 1)[1]
557557
for s in value.replace(",", ";").split(";")
558558
}
559559
else:
560560
# pyre-ignore[29]
561-
cfg[name] = runopt.opt_type(value)
561+
cfg[name] = opt.opt_type(value)

torchx/runner/test/api_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
parse_app_handle,
2929
Resource,
3030
Role,
31+
runopts,
3132
UnknownAppException,
3233
)
3334
from torchx.specs.finder import ComponentNotFoundException
@@ -701,3 +702,36 @@ def test_runner_manual_close(self, _) -> None:
701702
def test_get_default_runner(self, _) -> None:
702703
runner = get_runner()
703704
self.assertEqual("torchx", runner._name)
705+
706+
def test_cfg_from_str(self, _) -> None:
707+
scheduler_mock = MagicMock()
708+
opts = runopts()
709+
opts.add("foo", type_=str, default="", help="")
710+
opts.add("test_key", type_=str, default="", help="")
711+
opts.add("default_time", type_=int, default=0, help="")
712+
opts.add("enable", type_=bool, default=True, help="")
713+
opts.add("disable", type_=bool, default=True, help="")
714+
opts.add("complex_list", type_=List[str], default=[], help="")
715+
scheduler_mock.run_opts.return_value = opts
716+
717+
with Runner(
718+
name=SESSION_NAME,
719+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
720+
) as runner:
721+
self.assertDictEqual(
722+
{
723+
"foo": "bar",
724+
"test_key": "test_value",
725+
"default_time": 42,
726+
"enable": True,
727+
"disable": False,
728+
"complex_list": ["v1", "v2", "v3"],
729+
},
730+
runner.cfg_from_str(
731+
"local_dir",
732+
"foo=bar",
733+
"test_key=test_value",
734+
"default_time=42",
735+
"enable=True,disable=False,complex_list=v1;v2;v3",
736+
),
737+
)

torchx/runner/test/config_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,34 @@ def _run_opts(self) -> runopts:
9595
)
9696
opts.add(
9797
"l",
98-
type_=List[str],
98+
type_=list[str],
9999
default=["a", "b", "c"],
100100
help="a list option",
101101
)
102102
opts.add(
103-
"l_none",
103+
"l_typing",
104104
type_=List[str],
105+
default=["a", "b", "c"],
106+
help="a typing.List option",
107+
)
108+
opts.add(
109+
"l_none",
110+
type_=list[str],
105111
default=None,
106112
help="a None list option",
107113
)
108114
opts.add(
109115
"d",
110-
type_=Dict[str, str],
116+
type_=dict[str, str],
111117
default={"foo": "bar"},
112118
help="a dict option",
113119
)
120+
opts.add(
121+
"d_typing",
122+
type_=Dict[str, str],
123+
default={"foo": "bar"},
124+
help="a typing.Dict option",
125+
)
114126
opts.add(
115127
"d_none",
116128
type_=Dict[str, str],
@@ -151,6 +163,10 @@ def _run_opts(self) -> runopts:
151163
[test]
152164
s = my_default
153165
i = 100
166+
l = abc;def
167+
l_typing = ghi;jkl
168+
d = a:b,c:d
169+
d_typing = e:f,g:h
154170
"""
155171

156172
_MY_CONFIG2 = """#
@@ -387,6 +403,10 @@ def test_apply_dirs(self, _) -> None:
387403
self.assertEqual("runtime_value", cfg.get("s"))
388404
self.assertEqual(100, cfg.get("i"))
389405
self.assertEqual(1.2, cfg.get("f"))
406+
self.assertEqual({"a": "b", "c": "d"}, cfg.get("d"))
407+
self.assertEqual({"e": "f", "g": "h"}, cfg.get("d_typing"))
408+
self.assertEqual(["abc", "def"], cfg.get("l"))
409+
self.assertEqual(["ghi", "jkl"], cfg.get("l_typing"))
390410

391411
def test_dump_invalid_scheduler(self) -> None:
392412
with self.assertRaises(ValueError):
@@ -460,7 +480,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:
460480

461481
# all runopts in the TestScheduler have defaults, just check against those
462482
for opt_name, opt in TestScheduler("test").run_opts():
463-
self.assertEqual(cfg.get(opt_name), opt.default)
483+
self.assertEqual(opt.default, cfg.get(opt_name))
464484

465485
def test_dump_and_load_all_registered_schedulers(self) -> None:
466486
# dump all the runopts for all registered schedulers

torchx/schedulers/slurm_scheduler.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import shlex
1919
import subprocess
2020
import tempfile
21+
import warnings
2122
from dataclasses import dataclass
2223
from datetime import datetime
2324
from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
7273
return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
7374

7475

76+
def version() -> Tuple[int, int]:
77+
"""
78+
Uses ``sinfo --version`` to get the slurm version. If the command fails, it
79+
assumes the version is ``slurm 24.05.8``.
80+
81+
Returns:
82+
-------
83+
Tuple[int, int] slurm version as a tuple of ints (major, minor).
84+
"""
85+
86+
cmd = ["sinfo", "--version"]
87+
try:
88+
out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
89+
except (CalledProcessError, FileNotFoundError):
90+
out = "slurm 24.05.8"
91+
warnings.warn(
92+
"Error running: `sinfo --version` to get SLURM version. Are you running outside the "
93+
"cluster's login or head node? This typically happens when running in `--dryrun`"
94+
" mode. Assuming version is `slurm 24.05.8`.",
95+
RuntimeWarning,
96+
stacklevel=2,
97+
)
98+
99+
# sinfo --version returns in the form "slurm 24.1.0"
100+
_, version_literal = out.split(" ", maxsplit=2)
101+
major, minor = [int(v) for v in version_literal.split(".")][:2]
102+
103+
return (major, minor)
104+
105+
106+
def _should_use_gpus_per_node_from_version() -> bool:
107+
"""
108+
Determine whether to use gpus-per-node based on automatically detected slurm version.
109+
110+
Change Reference: https://fburl.com/sqwqzxn6
111+
> select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
112+
113+
Returns:
114+
``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
115+
"""
116+
117+
slurm_24_11_0 = (24, 11)
118+
slurm_version = version()
119+
120+
return slurm_version[0] > slurm_24_11_0[0] or ( # Major version is greater
121+
slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
122+
) # Major version is equal and minor version is greater or equal
123+
124+
75125
SBATCH_JOB_OPTIONS = {
76126
"comment",
77127
"mail-user",
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
81131
"partition",
82132
"time",
83133
"constraint",
134+
"qos",
84135
}
85136

86137
log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
106157
"mail-user": Optional[str],
107158
"mail-type": Optional[str],
108159
"job_dir": Optional[str],
160+
"qos": Optional[str],
109161
},
110162
total=False,
111163
)
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:
126178

127179
@classmethod
128180
def from_role(
129-
cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
181+
cls,
182+
name: str,
183+
role: Role,
184+
cfg: SlurmOpts,
185+
nomem: bool,
130186
) -> "SlurmReplicaRequest":
131187
"""
132188
``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
149205
if not nomem and resource.memMB > 0:
150206
sbatch_opts.setdefault("mem", str(resource.memMB))
151207
if resource.gpu > 0:
152-
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
208+
# Use smart GPU allocation based on automatically detected Slurm version
209+
if _should_use_gpus_per_node_from_version():
210+
sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
211+
else:
212+
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
153213

154214
srun_opts = {
155215
"output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
378438
iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379439
""",
380440
)
441+
opts.add(
442+
"qos",
443+
type_=str,
444+
help="Quality of Service (QoS) to assign to the job.",
445+
)
381446
return opts
382447

383448
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:

0 commit comments

Comments
 (0)