Skip to content

Commit ddbd0b8

Browse files
authored
numpyro.contrib.module: only collect mutables in nnx_module() when th… (#2061)
* numpyro.contrib.module: only collect mutables in nnx_module() when they exist Testing Done: avoided registration of empty mutables in nnx ConvNet Signed-off-by: Eli Sennesh <[email protected]> * test.contrib.test_module: expect mutables only when batchnorm or dropout active Signed-off-by: Eli Sennesh <[email protected]> --------- Signed-off-by: Eli Sennesh <[email protected]>
1 parent d4411dc commit ddbd0b8

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

numpyro/contrib/module.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,6 @@ def nnx_module(name, nn_module):
484484

485485
mutable_holder = None
486486
if eager_other_state_dict:
487-
mutable_holder = numpyro_mutable(name + "$state")
488-
if mutable_holder is None:
489487
mutable_holder = numpyro_mutable(
490488
name + "$state", {"state": eager_other_state_dict}
491489
)

test/contrib/test_module.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,12 @@ def model():
424424
with handlers.trace(model) as tr, handlers.seed(rng_seed=0):
425425
model()
426426

427-
assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"}
428-
assert tr["nn$state"]["type"] == "mutable"
427+
key_set = {"nn$params", "x", "y"}
428+
if batchnorm or dropout:
429+
key_set.add("nn$state")
430+
assert set(tr.keys()) == key_set
431+
if batchnorm or dropout:
432+
assert tr["nn$state"]["type"] == "mutable"
429433

430434
# test svi
431435
guide = AutoDelta(model)

0 commit comments

Comments
 (0)