@@ -17,6 +17,10 @@ class MeasurementError(Component):
17
17
Name of the measurement error component. Default is "MeasurementError".
18
18
observed_state_names : list[str] | None, optional
19
19
Names of the observed variables. If None, defaults to ["data"].
20
+ share_states: bool, default False
21
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
22
+ states, which are observed by all observed states. If False, each observed state has its own set of
23
+ latent states. This argument has no effect if `k_endog` is 1.
20
24
21
25
Notes
22
26
-----
@@ -93,11 +97,16 @@ class MeasurementError(Component):
93
97
"""
94
98
95
99
def __init__ (
96
- self , name : str = "MeasurementError" , observed_state_names : list [str ] | None = None
100
+ self ,
101
+ name : str = "MeasurementError" ,
102
+ observed_state_names : list [str ] | None = None ,
103
+ share_states : bool = False ,
97
104
):
98
105
if observed_state_names is None :
99
106
observed_state_names = ["data" ]
100
107
108
+ self .share_states = share_states
109
+
101
110
k_endog = len (observed_state_names )
102
111
k_states = 0
103
112
k_posdef = 0
@@ -113,25 +122,32 @@ def __init__(
113
122
)
114
123
115
124
def populate_component_properties (self ):
125
+ k_endog = self .k_endog
126
+ k_endog_effective = 1 if self .share_states else k_endog
127
+
116
128
self .param_names = [f"sigma_{ self .name } " ]
117
129
self .param_dims = {}
118
130
self .coords = {}
119
131
120
- if self . k_endog > 1 :
132
+ if k_endog_effective > 1 :
121
133
self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
122
134
self .coords [f"endog_{ self .name } " ] = self .observed_state_names
123
135
124
136
self .param_info = {
125
137
f"sigma_{ self .name } " : {
126
- "shape" : (self . k_endog ,) if self . k_endog > 1 else (),
138
+ "shape" : (k_endog_effective ,) if k_endog_effective > 1 else (),
127
139
"constraints" : "Positive" ,
128
- "dims" : (f"endog_{ self .name } " ,) if self . k_endog > 1 else None ,
140
+ "dims" : (f"endog_{ self .name } " ,) if k_endog_effective > 1 else None ,
129
141
}
130
142
}
131
143
132
144
def make_symbolic_graph (self ) -> None :
133
- sigma_shape = () if self .k_endog == 1 else (self .k_endog ,)
145
+ k_endog = self .k_endog
146
+ k_endog_effective = 1 if self .share_states else k_endog
147
+
148
+ sigma_shape = () if k_endog_effective == 1 else (k_endog_effective ,)
134
149
error_sigma = self .make_and_register_variable (f"sigma_{ self .name } " , shape = sigma_shape )
150
+
135
151
diag_idx = np .diag_indices (self .k_endog )
136
152
idx = np .s_ ["obs_cov" , diag_idx [0 ], diag_idx [1 ]]
137
153
self .ssm [idx ] = error_sigma ** 2
0 commit comments