Skip to content

Commit 0e63cfa

Browse files
authored
Fix memory leak when using scan (#1469)
1 parent 589b352 commit 0e63cfa

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

numpyro/ops/pytree.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def tree_flatten(self):
3232
# set to None to avoid leaks during tracing by JAX
3333
kwargs["rng_key"] = None
3434
aux_trace[name][key] = kwargs
35+
elif key == "infer":
36+
kwargs = site["infer"].copy()
37+
if "_scan_current_index" in kwargs:
38+
# set to None to avoid leaks during tracing by JAX
39+
kwargs["_scan_current_index"] = None
40+
aux_trace[name][key] = kwargs
3541
else:
3642
aux_trace[name][key] = site[key]
3743
# keep the site order information because in JAX, flatten and unflatten do not preserve

0 commit comments

Comments
 (0)