Skip to content

Commit e01c700

Browse files
serialization/deserialization of botorch modular components required for combinatorial generation strategy (#861)
Summary: Pull Request resolved: #861 This provides Ax with the ability to serialize/deserialize the botorch modular components that are required for dme65's new combinatorial generation strategy (created via `get_combinerator_gs`). Reviewed By: danielrjiang Differential Revision: D34721287 fbshipit-source-id: eb120f55dd3edfd32fc770e567f7a599b7ce5d72
1 parent 4cd47f1 commit e01c700

File tree

8 files changed

+158
-23
lines changed

8 files changed

+158
-23
lines changed

ax/storage/botorch_modular_registry.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,18 @@
4040
from botorch.models.model_list_gp_regression import ModelListGP
4141
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
4242

43+
# Miscellaneous BoTorch imports
44+
from gpytorch.constraints import Interval
45+
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
46+
from gpytorch.likelihoods.likelihood import Likelihood
47+
4348
# BoTorch `MarginalLogLikelihood` imports
4449
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
50+
from gpytorch.mlls.leave_one_out_pseudo_likelihood import LeaveOneOutPseudoLikelihood
4551
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
4652
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
47-
53+
from gpytorch.priors.torch_priors import GammaPrior
54+
from torch.nn import Module
4855

4956
# NOTE: When adding a new registry for a class, make sure to make changes
5057
# to `CLASS_TO_REGISTRY` and `CLASS_TO_REVERSE_REGISTRY` in this file.
@@ -93,18 +100,30 @@
93100
"""
94101
MLL_REGISTRY: Dict[Type[MarginalLogLikelihood], str] = {
95102
ExactMarginalLogLikelihood: "ExactMarginalLogLikelihood",
103+
LeaveOneOutPseudoLikelihood: "LeaveOneOutPseudoLikelihood",
96104
SumMarginalLogLikelihood: "SumMarginalLogLikelihood",
97105
}
98106

107+
LIKELIHOOD_REGISTRY: Dict[Type[GaussianLikelihood], str] = {
108+
GaussianLikelihood: "GaussianLikelihood"
109+
}
110+
111+
GPYTORCH_COMPONENT_REGISTRY: Dict[Type[Module], str] = {
112+
Interval: "Interval",
113+
GammaPrior: "GammaPrior",
114+
}
99115

100116
"""
101117
Overarching mapping from encoded classes to registry map.
102118
"""
103119
CLASS_TO_REGISTRY: Dict[Any, Dict[Type[Any], str]] = {
104120
Acquisition: ACQUISITION_REGISTRY,
105121
AcquisitionFunction: ACQUISITION_FUNCTION_REGISTRY,
122+
Likelihood: LIKELIHOOD_REGISTRY,
106123
MarginalLogLikelihood: MLL_REGISTRY,
107124
Model: MODEL_REGISTRY,
125+
Interval: GPYTORCH_COMPONENT_REGISTRY,
126+
GammaPrior: GPYTORCH_COMPONENT_REGISTRY,
108127
}
109128

110129

@@ -130,15 +149,25 @@
130149
v: k for k, v in MLL_REGISTRY.items()
131150
}
132151

152+
REVERSE_LIKELIHOOD_REGISTRY: Dict[str, Type[Likelihood]] = {
153+
v: k for k, v in LIKELIHOOD_REGISTRY.items()
154+
}
155+
156+
REVERSE_GPYTORCH_COMPONENT_REGISTRY: Dict[str, Type[Module]] = {
157+
v: k for k, v in GPYTORCH_COMPONENT_REGISTRY.items()
158+
}
133159

134160
"""
135161
Overarching mapping from encoded classes to reverse registry map.
136162
"""
137163
CLASS_TO_REVERSE_REGISTRY: Dict[Any, Dict[str, Type[Any]]] = {
138164
Acquisition: REVERSE_ACQUISITION_REGISTRY,
139165
AcquisitionFunction: REVERSE_ACQUISITION_FUNCTION_REGISTRY,
166+
Likelihood: REVERSE_LIKELIHOOD_REGISTRY,
140167
MarginalLogLikelihood: REVERSE_MLL_REGISTRY,
141168
Model: REVERSE_MODEL_REGISTRY,
169+
Interval: REVERSE_GPYTORCH_COMPONENT_REGISTRY,
170+
GammaPrior: REVERSE_GPYTORCH_COMPONENT_REGISTRY,
142171
}
143172

144173

ax/storage/json_store/decoder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
ModelRegistryBase,
3535
_decode_callables_from_references,
3636
)
37-
from ax.storage.json_store.decoders import batch_trial_from_json, trial_from_json
37+
from ax.storage.json_store.decoders import (
38+
batch_trial_from_json,
39+
trial_from_json,
40+
botorch_component_from_json,
41+
)
3842
from ax.storage.json_store.registry import (
3943
CORE_CLASS_DECODER_REGISTRY,
4044
CORE_DECODER_REGISTRY,
@@ -43,6 +47,7 @@
4347
from ax.utils.measurement import synthetic_functions
4448
from ax.utils.measurement.synthetic_functions import from_botorch
4549
from botorch.test_functions import synthetic as botorch_synthetic
50+
from torch.nn import Module
4651

4752

4853
def object_from_json(
@@ -156,6 +161,8 @@ def object_from_json(
156161
if isclass(_class) and issubclass(_class, Enum):
157162
# to access enum members by name, use item access
158163
return _class[object_json["name"]]
164+
elif isclass(_class) and issubclass(_class, Module):
165+
return botorch_component_from_json(botorch_class=_class, json=object_json)
159166
elif _class == GeneratorRun:
160167
return generator_run_from_json(
161168
object_json=object_json,

ax/storage/json_store/decoders.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import annotations
88

9+
import inspect
10+
import logging
911
from datetime import datetime
1012
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
1113

@@ -19,6 +21,9 @@
1921
from ax.storage.botorch_modular_registry import CLASS_TO_REVERSE_REGISTRY
2022
from ax.storage.transform_registry import REVERSE_TRANSFORM_REGISTRY
2123
from ax.utils.common.kwargs import warn_on_kwargs
24+
from ax.utils.common.logger import get_logger
25+
26+
logger: logging.Logger = get_logger(__name__)
2227

2328

2429
if TYPE_CHECKING:
@@ -161,3 +166,32 @@ def class_from_json(json: Dict[str, Any]) -> Type[Any]:
161166
f"{class_path} does not have a corresponding entry in "
162167
"CLASS_TO_REVERSE_REGISTRY."
163168
)
169+
170+
171+
def botorch_component_from_json(botorch_class: Any, json: Dict[str, Any]) -> Type[Any]:
172+
"""Load any instance of `gpytorch.Module` or descendent registered in
173+
`CLASS_DECODER_REGISTRY` from state dict."""
174+
class_path = json.pop("class")
175+
state_dict = json.pop("state_dict")
176+
init_args = inspect.signature(botorch_class).parameters
177+
required_args = {p for p, v in init_args.items() if v.default is inspect._empty}
178+
allowable_args = set(init_args)
179+
received_args = set(state_dict)
180+
missing_args = required_args - received_args
181+
if missing_args:
182+
raise ValueError(
183+
f"Missing required initialization args {missing_args} for class "
184+
f"{class_path}. For gpytorch objects, this is likely because the "
185+
"object's `state_dict` method does not return the args required "
186+
"for initialization."
187+
)
188+
extra_args = received_args - allowable_args
189+
if extra_args:
190+
raise ValueError(
191+
f"Received unused args {extra_args} for class {class_path}. "
192+
"For gpytorch objects, this is likely because the object's "
193+
"`state_dict` method returns these extra args, which could "
194+
"indicate that the object's state will not be fully recreated "
195+
"by this serialization/deserialization method."
196+
)
197+
return botorch_class(**state_dict)

ax/storage/json_store/encoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def object_to_json(
181181
return {"__type": f"torch_{_type.__name__}", "value": torch_type_to_str(obj)}
182182

183183
err = (
184-
f"Object {obj} passed to `object_to_json` (of type {_type}) is "
185-
f"not registered with a corresponding encoder in ENCODER_REGISTRY."
184+
f"Object {obj} passed to `object_to_json` (of type {_type}, module: "
185+
f"{_type.__module__}) is not registered with a corresponding encoder "
186+
"in ENCODER_REGISTRY."
186187
)
187188
raise JSONEncodeError(err)

ax/storage/json_store/encoders.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def botorch_modular_to_dict(class_type: Type[Any]) -> Dict[str, Any]:
487487
"located in `ax/storage/botorch_modular_registry.py`."
488488
)
489489
return {
490-
"__type": f"Type[{_class.__name__}]",
490+
"__type": "Type[Module]",
491491
"index": registry[class_type],
492492
"class": f"{_class}",
493493
}
@@ -497,6 +497,19 @@ def botorch_modular_to_dict(class_type: Type[Any]) -> Dict[str, Any]:
497497
)
498498

499499

500+
def botorch_component_to_dict(input_obj: Type[Any]) -> Dict[str, Any]:
501+
class_type = input_obj.__class__
502+
state_dict = input_obj.state_dict()
503+
# Cast dict values to float to avoid errors with Tensors.
504+
state_dict = {k: float(v) for k, v in state_dict.items()}
505+
return {
506+
"__type": f"{class_type.__name__}",
507+
"index": class_type,
508+
"class": f"{class_type}",
509+
"state_dict": state_dict,
510+
}
511+
512+
500513
def percentile_early_stopping_strategy_to_dict(
501514
strategy: PercentileEarlyStoppingStrategy,
502515
) -> Dict[str, Any]:

ax/storage/json_store/registry.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,15 @@
6060
from ax.models.torch.botorch_modular.model import BoTorchModel
6161
from ax.models.torch.botorch_modular.surrogate import Surrogate
6262
from ax.runners.synthetic import SyntheticRunner
63-
from ax.storage.json_store.decoders import class_from_json, transform_type_from_json
63+
from ax.storage.json_store.decoders import (
64+
class_from_json,
65+
transform_type_from_json,
66+
)
6467
from ax.storage.json_store.encoders import (
6568
arm_to_dict,
6669
batch_to_dict,
6770
benchmark_problem_to_dict,
71+
botorch_component_to_dict,
6872
botorch_model_to_dict,
6973
botorch_modular_to_dict,
7074
choice_parameter_to_dict,
@@ -101,7 +105,11 @@
101105
from ax.storage.utils import DomainType, ParameterConstraintType
102106
from botorch.acquisition.acquisition import AcquisitionFunction
103107
from botorch.models.model import Model
108+
from gpytorch.constraints import Interval
109+
from gpytorch.likelihoods.likelihood import Likelihood
104110
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
111+
from gpytorch.priors.torch_priors import GammaPrior
112+
from torch.nn import Module
105113

106114

107115
CORE_ENCODER_REGISTRY: Dict[Type, Callable[[Any], Dict[str, Any]]] = {
@@ -118,10 +126,12 @@
118126
Experiment: experiment_to_dict,
119127
FactorialMetric: metric_to_dict,
120128
FixedParameter: fixed_parameter_to_dict,
129+
GammaPrior: botorch_component_to_dict,
121130
GenerationStep: generation_step_to_dict,
122131
GenerationStrategy: generation_strategy_to_dict,
123132
GeneratorRun: generator_run_to_dict,
124133
Hartmann6Metric: metric_to_dict,
134+
Interval: botorch_component_to_dict,
125135
ListSurrogate: surrogate_to_dict,
126136
L2NormMetric: metric_to_dict,
127137
MapData: map_data_to_dict,
@@ -161,11 +171,13 @@
161171
# The encoder iterates through this dictionary and uses the first superclass that
162172
# it finds, which might not be the intended superclass.
163173
CORE_CLASS_ENCODER_REGISTRY: Dict[Type, Callable[[Any], Dict[str, Any]]] = {
164-
Acquisition: botorch_modular_to_dict,
165-
AcquisitionFunction: botorch_modular_to_dict,
166-
MarginalLogLikelihood: botorch_modular_to_dict,
167-
Model: botorch_modular_to_dict,
168-
Transform: transform_type_to_dict,
174+
Acquisition: botorch_modular_to_dict, # Ax MBM component
175+
AcquisitionFunction: botorch_modular_to_dict, # BoTorch component
176+
Likelihood: botorch_modular_to_dict, # BoTorch component
177+
Module: botorch_modular_to_dict, # BoTorch module
178+
MarginalLogLikelihood: botorch_modular_to_dict, # BoTorch component
179+
Model: botorch_modular_to_dict, # BoTorch component
180+
Transform: transform_type_to_dict, # Ax general (not just MBM) component
169181
}
170182

171183
CORE_DECODER_REGISTRY: Dict[str, Type] = {
@@ -189,12 +201,14 @@
189201
"Experiment": Experiment,
190202
"FactorialMetric": FactorialMetric,
191203
"FixedParameter": FixedParameter,
204+
"GammaPrior": GammaPrior,
192205
"GenerationStrategy": GenerationStrategy,
193206
"GenerationStep": GenerationStep,
194207
"GeneratorRun": GeneratorRun,
195208
"GeneratorRunStruct": GeneratorRunStruct,
196209
"Hartmann6Metric": Hartmann6Metric,
197210
"HierarchicalSearchSpace": HierarchicalSearchSpace,
211+
"Interval": Interval,
198212
"ListSurrogate": ListSurrogate,
199213
"L2NormMetric": L2NormMetric,
200214
"MapData": MapData,
@@ -238,6 +252,8 @@
238252
CORE_CLASS_DECODER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], Any]] = {
239253
"Type[Acquisition]": class_from_json,
240254
"Type[AcquisitionFunction]": class_from_json,
255+
"Type[Likelihood]": class_from_json,
256+
"Type[Module]": class_from_json,
241257
"Type[MarginalLogLikelihood]": class_from_json,
242258
"Type[Model]": class_from_json,
243259
"Type[Transform]": transform_type_from_json,

0 commit comments

Comments
 (0)