We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 589b352 commit 0e63cfaCopy full SHA for 0e63cfa
numpyro/ops/pytree.py
@@ -32,6 +32,12 @@ def tree_flatten(self):
32
# set to None to avoid leaks during tracing by JAX
33
kwargs["rng_key"] = None
34
aux_trace[name][key] = kwargs
35
+ elif key == "infer":
36
+ kwargs = site["infer"].copy()
37
+ if "_scan_current_index" in kwargs:
38
+ # set to None to avoid leaks during tracing by JAX
39
+ kwargs["_scan_current_index"] = None
40
+ aux_trace[name][key] = kwargs
41
else:
42
aux_trace[name][key] = site[key]
43
# keep the site order information because in JAX, flatten and unflatten do not preserve
0 commit comments