Skip to content

Commit 0c0669d

Browse files
authored
Make replay consistent with Pyro (#1345)
* make replay consistent with numpyro * Add docs for apply_stack to clarify its functionality
1 parent f9c756c commit 0c0669d

File tree

4 files changed

+47
-16
lines changed

4 files changed

+47
-16
lines changed

numpyro/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,8 @@ def has_enumerate_support(self):
762762
return self.base_dist.has_enumerate_support
763763

764764
@property
765-
def reparameterized_params(self):
766-
return self.base_dist.reparameterized_params
765+
def reparametrized_params(self):
766+
return self.base_dist.reparametrized_params
767767

768768
@property
769769
def mean(self):

numpyro/handlers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,27 @@ class replay(Messenger):
202202
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
203203
"""
204204

205-
def __init__(self, fn=None, trace=None, guide_trace=None):
205+
def __init__(self, fn=None, trace=None):
206206
assert trace is not None
207207
self.trace = trace
208208
super(replay, self).__init__(fn)
209209

210210
def process_message(self, msg):
211211
if msg["type"] in ("sample", "plate") and msg["name"] in self.trace:
212-
msg["value"] = self.trace[msg["name"]]["value"]
212+
name = msg["name"]
213+
if msg["type"] in ("sample", "plate") and name in self.trace:
214+
guide_msg = self.trace[name]
215+
if msg["type"] == "plate":
216+
if guide_msg["type"] != "plate":
217+
raise RuntimeError(f"Site {name} must be a plate in trace.")
218+
msg["value"] = guide_msg["value"]
219+
return None
220+
if msg["is_observed"]:
221+
return None
222+
if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
223+
raise RuntimeError(f"Site {name} must be sampled in trace.")
224+
msg["value"] = guide_msg["value"]
225+
msg["infer"] = guide_msg["infer"]
213226

214227

215228
class block(Messenger):

numpyro/primitives.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,39 @@
1818
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
1919

2020

21+
def default_process_message(msg):
22+
if msg["value"] is None:
23+
if msg["type"] == "sample":
24+
msg["value"], msg["intermediates"] = msg["fn"](
25+
*msg["args"], sample_intermediates=True, **msg["kwargs"]
26+
)
27+
else:
28+
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
29+
30+
2131
def apply_stack(msg):
32+
"""
33+
Execute the effect stack at a single site according to the following scheme:
34+
35+
1. For each ``Messenger`` in the stack from bottom to top,
36+
execute ``Messenger.process_message`` with the message;
37+
if the message field "stop" is True, stop;
38+
otherwise, continue
39+
2. Apply default behavior (``default_process_message``) to finish remaining
40+
site execution
41+
3. For each ``Messenger`` in the stack from top to bottom,
42+
execute ``Messenger.postprocess_message`` to update the message
43+
and internal messenger state with the site results
44+
"""
2245
pointer = 0
2346
for pointer, handler in enumerate(reversed(_PYRO_STACK)):
2447
handler.process_message(msg)
2548
# When a Messenger sets the "stop" field of a message,
2649
# it prevents any Messengers above it on the stack from being applied.
2750
if msg.get("stop"):
2851
break
29-
if msg["value"] is None:
30-
if msg["type"] == "sample":
31-
msg["value"], msg["intermediates"] = msg["fn"](
32-
*msg["args"], sample_intermediates=True, **msg["kwargs"]
33-
)
34-
else:
35-
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
52+
53+
default_process_message(msg)
3654

3755
# A Messenger that sets msg["stop"] == True also prevents application
3856
# of postprocess_message by Messengers above it on the stack

test/test_distributions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -894,23 +894,23 @@ def test_sample_gradient(jax_dist, sp_dist, params):
894894
params_dict = dict(zip(dist_args[: len(params)], params))
895895

896896
jax_class = type(jax_dist(**params_dict))
897-
reparameterized_params = [
897+
reparametrized_params = [
898898
p for p in jax_class.reparametrized_params if p not in gamma_derived_params
899899
]
900-
if not reparameterized_params:
900+
if not reparametrized_params:
901901
pytest.skip("{} not reparametrized.".format(jax_class.__name__))
902902

903903
nonrepara_params_dict = {
904-
k: v for k, v in params_dict.items() if k not in reparameterized_params
904+
k: v for k, v in params_dict.items() if k not in reparametrized_params
905905
}
906906
repara_params = tuple(
907-
v for k, v in params_dict.items() if k in reparameterized_params
907+
v for k, v in params_dict.items() if k in reparametrized_params
908908
)
909909

910910
rng_key = random.PRNGKey(0)
911911

912912
def fn(args):
913-
args_dict = dict(zip(reparameterized_params, args))
913+
args_dict = dict(zip(reparametrized_params, args))
914914
return jnp.sum(
915915
jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key)
916916
)

0 commit comments

Comments
 (0)