Skip to content

Commit 1341c52

Browse files
Add shared_state argument to MeasurementError
1 parent 7a8bdf1 commit 1341c52

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

pymc_extras/statespace/models/structural/components/measurement_error.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class MeasurementError(Component):
1717
Name of the measurement error component. Default is "MeasurementError".
1818
observed_state_names : list[str] | None, optional
1919
Names of the observed variables. If None, defaults to ["data"].
20+
share_states: bool, default False
21+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
22+
states, which are observed by all observed states. If False, each observed state has its own set of
23+
latent states. This argument has no effect if `k_endog` is 1.
2024
2125
Notes
2226
-----
@@ -93,11 +97,16 @@ class MeasurementError(Component):
9397
"""
9498

9599
def __init__(
96-
self, name: str = "MeasurementError", observed_state_names: list[str] | None = None
100+
self,
101+
name: str = "MeasurementError",
102+
observed_state_names: list[str] | None = None,
103+
share_states: bool = False,
97104
):
98105
if observed_state_names is None:
99106
observed_state_names = ["data"]
100107

108+
self.share_states = share_states
109+
101110
k_endog = len(observed_state_names)
102111
k_states = 0
103112
k_posdef = 0
@@ -113,25 +122,32 @@ def __init__(
113122
)
114123

115124
def populate_component_properties(self):
125+
k_endog = self.k_endog
126+
k_endog_effective = 1 if self.share_states else k_endog
127+
116128
self.param_names = [f"sigma_{self.name}"]
117129
self.param_dims = {}
118130
self.coords = {}
119131

120-
if self.k_endog > 1:
132+
if k_endog_effective > 1:
121133
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
122134
self.coords[f"endog_{self.name}"] = self.observed_state_names
123135

124136
self.param_info = {
125137
f"sigma_{self.name}": {
126-
"shape": (self.k_endog,) if self.k_endog > 1 else (),
138+
"shape": (k_endog_effective,) if k_endog_effective > 1 else (),
127139
"constraints": "Positive",
128-
"dims": (f"endog_{self.name}",) if self.k_endog > 1 else None,
140+
"dims": (f"endog_{self.name}",) if k_endog_effective > 1 else None,
129141
}
130142
}
131143

132144
def make_symbolic_graph(self) -> None:
133-
sigma_shape = () if self.k_endog == 1 else (self.k_endog,)
145+
k_endog = self.k_endog
146+
k_endog_effective = 1 if self.share_states else k_endog
147+
148+
sigma_shape = () if k_endog_effective == 1 else (k_endog_effective,)
134149
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)
150+
135151
diag_idx = np.diag_indices(self.k_endog)
136152
idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]
137153
self.ssm[idx] = error_sigma**2

tests/statespace/models/structural/components/test_measurement_error.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import numpy as np
2+
import pytensor
3+
4+
from pytensor.graph.basic import explicit_graph_inputs
25

36
from pymc_extras.statespace.models import structural as st
47
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
@@ -19,6 +22,45 @@ def test_measurement_error_multiple_observed():
1922
assert mod.param_dims["sigma_obs"] == ("endog_obs",)
2023

2124

25+
def test_measurement_error_share_states():
26+
mod = st.MeasurementError("obs", observed_state_names=["data_1", "data_2"], share_states=True)
27+
mod.build(verbose=False)
28+
29+
assert mod.k_endog == 2
30+
assert mod.param_names == ["sigma_obs", "P0"]
31+
assert "endog_obs" not in mod.coords
32+
33+
# Check that the parameter is shared across the observed states
34+
assert mod.param_info["sigma_obs"]["shape"] == ()
35+
36+
outputs = mod.ssm["obs_cov"]
37+
38+
H = pytensor.function(list(explicit_graph_inputs([outputs])), outputs)(sigma_obs=np.array(0.5))
39+
np.testing.assert_allclose(H, np.diag([0.5, 0.5]) ** 2)
40+
41+
42+
def test_measurement_error_shared_and_not_shared():
43+
shared = st.MeasurementError(
44+
"error_shared", observed_state_names=["data_1", "data_2"], share_states=True
45+
)
46+
individual = st.MeasurementError("error_individual", observed_state_names=["data_1", "data_2"])
47+
mod = (shared + individual).build(verbose=False)
48+
49+
assert mod.k_endog == 2
50+
assert mod.param_names == ["sigma_error_shared", "sigma_error_individual", "P0"]
51+
assert mod.coords["endog_error_individual"] == ["data_1", "data_2"]
52+
53+
assert mod.param_info["sigma_error_shared"]["shape"] == ()
54+
assert mod.param_info["sigma_error_individual"]["shape"] == (2,)
55+
56+
outputs = mod.ssm["obs_cov"]
57+
58+
H = pytensor.function(list(explicit_graph_inputs([outputs])), outputs)(
59+
sigma_error_shared=np.array(0.5), sigma_error_individual=np.array([0.1, 0.9])
60+
)
61+
np.testing.assert_allclose(H, np.diag([0.5, 0.5]) ** 2 + np.diag([0.1, 0.9]) ** 2)
62+
63+
2264
def test_build_with_measurement_error_subset():
2365
ll = st.LevelTrendComponent(order=2, observed_state_names=["data_1", "data_2", "data_3"])
2466
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_3"])

0 commit comments

Comments
 (0)