Skip to content

Commit 8e2e68f

Browse files
Eugen Hotajfacebook-github-bot
authored andcommitted
Move TorchXRunner into Ax. (#855)
Summary: Pull Request resolved: #855 X-link: pytorch/torchx#427 As title. Reviewed By: lena-kashtelyan Differential Revision: D34928063 fbshipit-source-id: 41fa86d6cc789dc5e84228f037df82eb3a7847d9
1 parent e01c700 commit 8e2e68f

File tree

7 files changed

+470
-3
lines changed

7 files changed

+470
-3
lines changed

.github/workflows/build-and-test.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ jobs:
2929
pip install git+https://github.com/cornellius-gp/gpytorch.git
3030
pip install git+https://github.com/pytorch/botorch.git
3131
pip install -e .[dev,mysql,notebook]
32-
pip install tensorboard # For tensorboard unit tests
32+
pip install tensorboard # For tensorboard unit tests.
33+
pip install torchx # For torchx unit tests.
3334
- name: Tests and coverage
3435
run: |
3536
pytest -ra --cov=ax
@@ -83,7 +84,8 @@ jobs:
8384
pip install git+https://github.com/cornellius-gp/gpytorch.git
8485
pip install git+https://github.com/pytorch/botorch.git
8586
pip install -e .[dev,mysql,notebook]
86-
pip install tensorboard # For generating Sphinx docs for TensorboardCurveMetric
87+
pip install tensorboard # For generating Sphinx docs for TensorboardCurveMetric.
88+
pip install torchx # For generating Sphinx docs for TorchXMetric.
8789
- name: Validate Sphinx
8890
run: |
8991
python scripts/validate_sphinx.py -p "${pwd}"

ax/metrics/torchx.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, cast
7+
8+
import pandas as pd
9+
from ax.core import Trial
10+
from ax.core.base_trial import BaseTrial
11+
from ax.core.data import Data
12+
from ax.core.metric import Metric
13+
from ax.runners.torchx import TORCHX_TRACKER_BASE
14+
from ax.utils.common.logger import get_logger
15+
from ax.utils.common.typeutils import not_none
16+
17+
logger = get_logger(__name__)
18+
19+
try:
20+
from torchx.runtime.tracking import FsspecResultTracker
21+
22+
class TorchXMetric(Metric):
23+
"""
24+
Fetches AppMetric (the observation returned by the trial job/app) via the
25+
``torchx.tracking`` module. Assumes that the app used the tracker in the
26+
following manner:
27+
28+
.. code-block:: python
29+
30+
tracker = torchx.runtime.tracking.FsspecResultTracker(tracker_base)
31+
tracker[str(trial_index)] = {metric_name: value}
32+
33+
# -- or --
34+
tracker[str(trial_index)] = {"metric_name/mean": mean_value,
35+
"metric_name/sem": sem_value}
36+
37+
"""
38+
39+
def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
40+
41+
tracker_base = trial.run_metadata[TORCHX_TRACKER_BASE]
42+
tracker = FsspecResultTracker(tracker_base)
43+
res = tracker[trial.index]
44+
45+
if self.name in res:
46+
mean = res[self.name]
47+
sem = None
48+
else:
49+
mean = res.get(f"{self.name}/mean")
50+
sem = res.get(f"{self.name}/sem")
51+
52+
if mean is None and sem is None:
53+
raise KeyError(
54+
f"Observation for `{self.name}` not found in tracker at base "
55+
f"`{tracker_base}`. Ensure that the trial job is writing the "
56+
"results at the same tracker base."
57+
)
58+
59+
df_dict = {
60+
"arm_name": not_none(cast(Trial, trial).arm).name,
61+
"trial_index": trial.index,
62+
"metric_name": self.name,
63+
"mean": mean,
64+
"sem": sem,
65+
}
66+
return Data(df=pd.DataFrame.from_records([df_dict]))
67+
68+
69+
except ImportError:
70+
logger.warning(
71+
"torchx package not found. If you would like to use TorchXMetric, please "
72+
"install torchx."
73+
)
74+
pass

ax/runners/synthetic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class SyntheticRunner(Runner):
1616
Currently acts as a shell runner, only creating a name.
1717
"""
1818

19-
def __init__(self, dummy_metadata: Optional[str] = None):
19+
def __init__(self, dummy_metadata: Optional[str] = None) -> None:
2020
self.dummy_metadata = dummy_metadata
2121

2222
def run(self, trial: BaseTrial) -> Dict[str, Any]:

ax/runners/tests/test_torchx.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
import shutil
9+
import tempfile
10+
from typing import List
11+
12+
from ax.core import (
13+
BatchTrial,
14+
Experiment,
15+
Objective,
16+
OptimizationConfig,
17+
Parameter,
18+
ParameterType,
19+
RangeParameter,
20+
SearchSpace,
21+
)
22+
from ax.metrics.torchx import TorchXMetric
23+
from ax.modelbridge.dispatch_utils import choose_generation_strategy
24+
from ax.runners.torchx import TorchXRunner
25+
from ax.service.scheduler import SchedulerOptions, Scheduler, FailureRateExceededError
26+
from ax.utils.common.constants import Keys
27+
from ax.utils.common.testutils import TestCase
28+
from torchx.components import utils
29+
30+
31+
class TorchXRunnerTest(TestCase):
32+
def setUp(self) -> None:
33+
self.test_dir = tempfile.mkdtemp("torchx_runtime_hpo_ax_test")
34+
35+
self.old_cwd = os.getcwd()
36+
os.chdir(os.path.dirname(os.path.dirname(__file__)))
37+
38+
self._parameters: List[Parameter] = [
39+
RangeParameter(
40+
name="x1",
41+
lower=-10.0,
42+
upper=10.0,
43+
parameter_type=ParameterType.FLOAT,
44+
),
45+
RangeParameter(
46+
name="x2",
47+
lower=-10.0,
48+
upper=10.0,
49+
parameter_type=ParameterType.FLOAT,
50+
),
51+
]
52+
53+
self._minimize = True
54+
self._objective = Objective(
55+
metric=TorchXMetric(
56+
name="booth_eval",
57+
),
58+
minimize=self._minimize,
59+
)
60+
61+
self._runner = TorchXRunner(
62+
tracker_base=self.test_dir,
63+
component=utils.booth,
64+
scheduler="local_cwd",
65+
cfg={"prepend_cwd": True},
66+
)
67+
68+
def tearDown(self) -> None:
69+
shutil.rmtree(self.test_dir)
70+
os.chdir(self.old_cwd)
71+
72+
def test_run_experiment_locally(self) -> None:
73+
"""Runs optimization over n rounds of k sequential trials."""
74+
75+
experiment = Experiment(
76+
name="torchx_booth_sequential_demo",
77+
search_space=SearchSpace(parameters=self._parameters),
78+
optimization_config=OptimizationConfig(objective=self._objective),
79+
runner=self._runner,
80+
is_test=True,
81+
properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
82+
)
83+
84+
scheduler = Scheduler(
85+
experiment=experiment,
86+
generation_strategy=(
87+
choose_generation_strategy(
88+
search_space=experiment.search_space,
89+
)
90+
),
91+
options=SchedulerOptions(),
92+
)
93+
94+
try:
95+
for _ in range(3):
96+
scheduler.run_n_trials(max_trials=2)
97+
98+
# TorchXMetric always returns trial index; hence the best experiment for min
99+
# objective will be the params for trial 0.
100+
scheduler.report_results()
101+
except FailureRateExceededError:
102+
pass # TODO(ehotaj): Figure out why this test fails in OSS.
103+
# Nothing to assert, just make sure experiment runs.
104+
105+
def test_stop_trials(self) -> None:
106+
experiment = Experiment(
107+
name="torchx_booth_sequential_demo",
108+
search_space=SearchSpace(parameters=self._parameters),
109+
optimization_config=OptimizationConfig(objective=self._objective),
110+
runner=self._runner,
111+
is_test=True,
112+
properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
113+
)
114+
scheduler = Scheduler(
115+
experiment=experiment,
116+
generation_strategy=(
117+
choose_generation_strategy(
118+
search_space=experiment.search_space,
119+
)
120+
),
121+
options=SchedulerOptions(),
122+
)
123+
scheduler.run(max_new_trials=3)
124+
trial = scheduler.running_trials[0]
125+
reason = self._runner.stop(trial, reason="some_reason")
126+
self.assertEqual(reason, {"reason": "some_reason"})
127+
128+
def test_run_experiment_locally_in_batches(self) -> None:
129+
"""Runs optimization over k x n rounds of k parallel trials.
130+
131+
This asks Ax to run up to max_parallelism_cap trials in parallel by submitting
132+
them to the scheduler at the same time.
133+
134+
NOTE:
135+
* setting max_parallelism_cap in generation_strategy
136+
* setting run_trials_in_batches in scheduler options
137+
* setting total_trials = parallelism * rounds
138+
139+
"""
140+
parallelism = 2
141+
rounds = 3
142+
143+
experiment = Experiment(
144+
name="torchx_booth_parallel_demo",
145+
search_space=SearchSpace(parameters=self._parameters),
146+
optimization_config=OptimizationConfig(objective=self._objective),
147+
runner=self._runner,
148+
is_test=True,
149+
properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
150+
)
151+
152+
scheduler = Scheduler(
153+
experiment=experiment,
154+
generation_strategy=(
155+
choose_generation_strategy(
156+
search_space=experiment.search_space,
157+
max_parallelism_cap=parallelism,
158+
)
159+
),
160+
options=SchedulerOptions(
161+
run_trials_in_batches=True, total_trials=(parallelism * rounds)
162+
),
163+
)
164+
165+
try:
166+
scheduler.run_all_trials()
167+
168+
# TorchXMetric always returns trial index; hence the best experiment for min
169+
# objective will be the params for trial 0.
170+
scheduler.report_results()
171+
except FailureRateExceededError:
172+
pass # TODO(ehotaj): Figure out why this test fails in OSS.
173+
# Nothing to assert, just make sure experiment runs.
174+
175+
def test_runner_no_batch_trials(self) -> None:
176+
experiment = Experiment(
177+
name="runner_test",
178+
search_space=SearchSpace(parameters=self._parameters),
179+
optimization_config=OptimizationConfig(objective=self._objective),
180+
runner=self._runner,
181+
is_test=True,
182+
)
183+
184+
with self.assertRaises(ValueError):
185+
self._runner.run(trial=BatchTrial(experiment))

0 commit comments

Comments
 (0)