Skip to content

Commit 703dc27

Browse files
jessegrabowskiricardoV94
authored andcommitted
Remove shape and dim from MeasurementError sigma (it's a scalar)
1 parent 381745e commit 703dc27

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

pymc_experimental/statespace/models/structural.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ALL_STATE_DIM,
2121
AR_PARAM_DIM,
2222
LONG_MATRIX_NAMES,
23-
OBS_STATE_DIM,
2423
POSITION_DERIVATIVE_NAMES,
2524
TIME_DIM,
2625
)
@@ -910,17 +909,17 @@ def __init__(self, name: str = "MeasurementError"):
910909

911910
def populate_component_properties(self):
912911
self.param_names = [f"sigma_{self.name}"]
913-
self.param_dims = {f"sigma_{self.name}": (OBS_STATE_DIM,)}
912+
self.param_dims = {}
914913
self.param_info = {
915914
f"sigma_{self.name}": {
916-
"shape": (self.k_endog,),
915+
"shape": (),
917916
"constraints": "Positive",
918-
"dims": (OBS_STATE_DIM,),
917+
"dims": None,
919918
}
920919
}
921920

922921
def make_symbolic_graph(self) -> None:
923-
sigma_shape = () if self.k_endog == 1 else (self.k_endog,)
922+
sigma_shape = ()
924923
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)
925924
diag_idx = np.diag_indices(self.k_endog)
926925
idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]

0 commit comments

Comments
 (0)