#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
from typing import List

from ax.core import (
    BatchTrial,
    Experiment,
    Objective,
    OptimizationConfig,
    Parameter,
    ParameterType,
    RangeParameter,
    SearchSpace,
)
from ax.metrics.torchx import TorchXMetric
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.runners.torchx import TorchXRunner
from ax.service.scheduler import SchedulerOptions, Scheduler, FailureRateExceededError
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from torchx.components import utils


class TorchXRunnerTest(TestCase):
    def setUp(self) -> None:
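        # Scratch directory used as the TorchX tracker base below; removed in
        # tearDown.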
        self.test_dir = tempfile.mkdtemp("torchx_runtime_hpo_ax_test")

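        # Run from the package root so the local_cwd TorchX scheduler (with
        # prepend_cwd=True) can resolve the component; the original working
        # directory is restored in tearDown.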
        self.old_cwd = os.getcwd()
        os.chdir(os.path.dirname(os.path.dirname(__file__)))

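        # Two-dimensional search space for the Booth function: x1, x2 in [-10, 10].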
        self._parameters: List[Parameter] = [
            RangeParameter(
                name="x1",
                lower=-10.0,
                upper=10.0,
                parameter_type=ParameterType.FLOAT,
            ),
            RangeParameter(
                name="x2",
                lower=-10.0,
                upper=10.0,
                parameter_type=ParameterType.FLOAT,
            ),
        ]

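        # Minimize the booth_eval metric reported back by the TorchX jobs.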
        self._minimize = True
        self._objective = Objective(
            metric=TorchXMetric(
                name="booth_eval",
            ),
            minimize=self._minimize,
        )

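        # Runner that deploys each trial as a TorchX booth component on the
        # local_cwd scheduler, using test_dir as the tracker base.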
        self._runner = TorchXRunner(
            tracker_base=self.test_dir,
            component=utils.booth,
            scheduler="local_cwd",
            cfg={"prepend_cwd": True},
        )

    def tearDown(self) -> None:
        shutil.rmtree(self.test_dir)
        os.chdir(self.old_cwd)

    def test_run_experiment_locally(self) -> None:
        """Runs optimization over n rounds of k sequential trials."""

        experiment = Experiment(
            name="torchx_booth_sequential_demo",
            search_space=SearchSpace(parameters=self._parameters),
            optimization_config=OptimizationConfig(objective=self._objective),
            runner=self._runner,
            is_test=True,
            properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
        )

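        # The Scheduler pairs the experiment with a generation strategy that Ax
        # picks automatically for this search space and deploys trials through
        # the TorchX runner attached above.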
        scheduler = Scheduler(
            experiment=experiment,
            generation_strategy=(
                choose_generation_strategy(
                    search_space=experiment.search_space,
                )
            ),
            options=SchedulerOptions(),
        )

        try:
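            # Three rounds of up to two trials each, i.e. at most six trials total.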
            for _ in range(3):
                scheduler.run_n_trials(max_trials=2)

            # TorchXMetric always returns the trial index; hence for a minimization
            # objective the best parameters will be those of trial 0.
            scheduler.report_results()
        except FailureRateExceededError:
            pass  # TODO(ehotaj): Figure out why this test fails in OSS.
        # Nothing to assert, just make sure the experiment runs.

    def test_stop_trials(self) -> None:
        experiment = Experiment(
            name="torchx_booth_sequential_demo",
            search_space=SearchSpace(parameters=self._parameters),
            optimization_config=OptimizationConfig(objective=self._objective),
            runner=self._runner,
            is_test=True,
            properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
        )
        scheduler = Scheduler(
            experiment=experiment,
            generation_strategy=(
                choose_generation_strategy(
                    search_space=experiment.search_space,
                )
            ),
            options=SchedulerOptions(),
        )
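        # Start up to three trials, then ask the runner to stop the first one
        # still running and check the returned stop reason.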
        scheduler.run(max_new_trials=3)
        trial = scheduler.running_trials[0]
        reason = self._runner.stop(trial, reason="some_reason")
        self.assertEqual(reason, {"reason": "some_reason"})

    def test_run_experiment_locally_in_batches(self) -> None:
        """Runs optimization over n rounds of k parallel trials.

        This asks Ax to run up to max_parallelism_cap trials in parallel by
        submitting them to the scheduler at the same time.

        NOTE: this relies on
            * setting max_parallelism_cap in the generation strategy,
            * setting run_trials_in_batches in the scheduler options,
            * setting total_trials = parallelism * rounds.

        """
        parallelism = 2
        rounds = 3
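        # With run_trials_in_batches, each round submits `parallelism` trials at
        # once: total_trials = parallelism * rounds = 2 * 3 = 6.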

        experiment = Experiment(
            name="torchx_booth_parallel_demo",
            search_space=SearchSpace(parameters=self._parameters),
            optimization_config=OptimizationConfig(objective=self._objective),
            runner=self._runner,
            is_test=True,
            properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
        )

        scheduler = Scheduler(
            experiment=experiment,
            generation_strategy=(
                choose_generation_strategy(
                    search_space=experiment.search_space,
                    max_parallelism_cap=parallelism,
                )
            ),
            options=SchedulerOptions(
                run_trials_in_batches=True, total_trials=(parallelism * rounds)
            ),
        )

        try:
            scheduler.run_all_trials()

            # TorchXMetric always returns the trial index; hence for a minimization
            # objective the best parameters will be those of trial 0.
            scheduler.report_results()
        except FailureRateExceededError:
            pass  # TODO(ehotaj): Figure out why this test fails in OSS.
        # Nothing to assert, just make sure the experiment runs.

    def test_runner_no_batch_trials(self) -> None:
        experiment = Experiment(
            name="runner_test",
            search_space=SearchSpace(parameters=self._parameters),
            optimization_config=OptimizationConfig(objective=self._objective),
            runner=self._runner,
            is_test=True,
        )

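        # The runner does not support BatchTrials; running one should raise a
        # ValueError.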
        with self.assertRaises(ValueError):
            self._runner.run(trial=BatchTrial(experiment))