Skip to content

Commit baabb59

Browse files
authored
Merge branch 'main' into patch-1
2 parents f9ac29b + 6641ab3 commit baabb59

34 files changed

+2168
-169
lines changed

.github/workflows/slurm-local-integration-tests.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ on:
66
- main
77
pull_request:
88

9+
910
env:
10-
SLURM_VERSION: 21.08.6
11+
# slurm tag should be one of https://github.com/SchedMD/slurm/tags
12+
SLURM_TAG: slurm-23-11-11-1
13+
SLURM_VERSION: 23.11.11
1114

1215
jobs:
1316
slurm:

.pyre_configuration

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
"stubs"
1919
],
2020
"strict": true,
21-
"version": "0.0.101732536891"
21+
"enable_strict_any_check": true,
22+
"version": "0.0.101749035478"
2223
}

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ torchmetrics==1.6.3
3030
torchserve>=0.10.0
3131
torchtext==0.18.0
3232
torchvision==0.22.0
33+
typing-extensions
3334
ts==0.5.1
3435
ray[default]
3536
wheel

requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
pyre-extensions
21
docstring-parser>=0.8.1
3-
importlib-metadata
42
pyyaml
53
docker
64
filelock

torchx/cli/cmd_list.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
HANDLE_HEADER = "APP HANDLE"
2323
STATUS_HEADER = "APP STATUS"
24+
NAME_HEADER = "APP NAME"
2425

2526

2627
class CmdList(SubCommand):
@@ -39,5 +40,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
3940
def run(self, args: argparse.Namespace) -> None:
4041
with get_runner() as runner:
4142
apps = runner.list(args.scheduler)
42-
apps_data = [[app.app_handle, str(app.state)] for app in apps]
43-
print(tabulate(apps_data, headers=[HANDLE_HEADER, STATUS_HEADER]))
43+
apps_data = [[app.app_handle, app.name, str(app.state)] for app in apps]
44+
print(
45+
tabulate(apps_data, headers=[HANDLE_HEADER, NAME_HEADER, STATUS_HEADER])
46+
)

torchx/components/dist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,6 @@ def spmd(
132132
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133133
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
134134
max_retries: the number of scheduler retries allowed
135-
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
136-
Only takes effect when running multi-node. When running single node, this parameter
137-
is ignored and a random free port is chosen.
138135
mounts: (for docker based runs only) mounts to mount into the worker environment/container
139136
(ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
140137
debug: whether to run with preset debug flags enabled
@@ -174,6 +171,7 @@ def ddp(
174171
max_retries: int = 0,
175172
rdzv_port: int = 29500,
176173
rdzv_backend: str = "c10d",
174+
rdzv_conf: Optional[str] = None,
177175
mounts: Optional[List[str]] = None,
178176
debug: bool = False,
179177
tee: int = 3,
@@ -208,6 +206,7 @@ def ddp(
208206
Only takes effect when running multi-node. When running single node, this parameter
209207
is ignored and a random free port is chosen.
210208
rdzv_backend: the rendezvous backend to use. Only takes effect when running multi-node.
209+
rdzv_conf: the additional rendezvous configuration to use (ex. join_timeout=600,close_timeout=600,timeout=600).
211210
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
212211
See scheduler documentation for more info.
213212
debug: whether to run with preset debug flags enabled
@@ -258,6 +257,7 @@ def ddp(
258257
"torchrun",
259258
"--rdzv_backend",
260259
rdzv_backend,
260+
*(["--rdzv_conf", rdzv_conf] if rdzv_conf is not None else []),
261261
"--rdzv_endpoint",
262262
rdzv_endpoint,
263263
"--rdzv_id",

torchx/components/structured_arg.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
from pathlib import Path
3131
from typing import Optional
3232

33-
from pyre_extensions import none_throws
34-
3533
from torchx import specs
3634

3735

@@ -148,7 +146,8 @@ def parse_from(
148146
if m: # use the last module name
149147
run_name = m.rpartition(".")[2]
150148
else: # use script name w/ no extension
151-
run_name = Path(none_throws(script)).stem
149+
assert script, "`script` can't be `None` here due checks above"
150+
run_name = Path(script).stem
152151
return StructuredNameArgument(
153152
experiment_name or default_experiment_name, run_name
154153
)

torchx/components/test/dist_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def test_ddp_debug(self) -> None:
4141
self.assertEqual(env[k], v)
4242

4343
def test_ddp_rdzv_backend_static(self) -> None:
44-
app = ddp(script="foo.py", rdzv_backend="static")
44+
rdzv_conf = "join_timeout=600,close_timeout=600,timeout=600"
45+
app = ddp(script="foo.py", rdzv_backend="static", rdzv_conf=rdzv_conf)
4546
cmd = app.roles[0].args[1]
47+
self.assertTrue(f"--rdzv_conf {rdzv_conf}" in cmd)
4648
self.assertTrue("--rdzv_backend static" in cmd)
4749
self.assertTrue("--node_rank" in cmd)
4850

torchx/runner/api.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
import warnings
1515
from datetime import datetime
1616
from types import TracebackType
17-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar
17+
from typing import (
18+
Any,
19+
Dict,
20+
Iterable,
21+
List,
22+
Mapping,
23+
Optional,
24+
Tuple,
25+
Type,
26+
TYPE_CHECKING,
27+
TypeVar,
28+
)
1829

1930
from torchx.runner.events import log_event
2031
from torchx.schedulers import get_scheduler_factories, SchedulerFactory
@@ -44,6 +55,9 @@
4455
from torchx.util.types import none_throws
4556
from torchx.workspace.api import PkgInfo, WorkspaceBuilder, WorkspaceMixin
4657

58+
if TYPE_CHECKING:
59+
from typing_extensions import Self
60+
4761
from .config import get_config, get_configs
4862

4963
logger: logging.Logger = logging.getLogger(__name__)
@@ -120,7 +134,7 @@ def _get_scheduler_params_from_env(self) -> Dict[str, str]:
120134
scheduler_params[lower_case_key.strip("torchx_")] = value
121135
return scheduler_params
122136

123-
def __enter__(self) -> "Runner":
137+
def __enter__(self) -> "Self":
124138
return self
125139

126140
def __exit__(

torchx/runner/test/config_test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,20 +470,17 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
470470
sfile = StringIO()
471471
dump(sfile)
472472

473-
scheduler_factories = {
474-
**get_scheduler_factories(),
475-
**(
476-
get_scheduler_factories(
477-
group="torchx.schedulers.orchestrator", skip_defaults=True
478-
)
479-
or {}
480-
),
481-
}
473+
scheduler_factories = get_scheduler_factories()
482474

483475
for sched_name, sched in scheduler_factories.items():
484476
sfile.seek(0) # reset the file pos
485477
cfg = {}
486-
load(scheduler=sched_name, f=sfile, cfg=cfg)
478+
try:
479+
load(scheduler=sched_name, f=sfile, cfg=cfg)
480+
except ModuleNotFoundError:
481+
# just test the ones that have been installed
482+
continue
483+
487484
for opt_name, _ in sched("test").run_opts():
488485
self.assertTrue(
489486
opt_name in cfg,

0 commit comments

Comments
 (0)