Skip to content

Commit 00d825b

Browse files
Be more consistent with the names of filter outputs
1 parent 7bee2d8 commit 00d825b

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -758,22 +758,28 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
758758
@staticmethod
759759
def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVariable]) -> None:
760760
mod = modelcontext(None)
761+
coords = mod.coords
762+
761763
states, covs = outputs[:4], outputs[4:]
762764

763-
state_names = ["filtered_state", "predicted_state", "observed_state", "smoothed_state"]
765+
state_names = [
766+
"filtered_state",
767+
"predicted_state",
768+
"predicted_observed_state",
769+
"smoothed_state",
770+
]
764771
cov_names = [
765772
"filtered_covariance",
766773
"predicted_covariance",
767-
"observed_covariance",
774+
"predicted_observed_covariance",
768775
"smoothed_covariance",
769776
]
770777

771778
with mod:
772-
for state, name in zip(states, state_names):
773-
pm.Deterministic(name, state, dims=FILTER_OUTPUT_DIMS.get(name, None))
774-
775-
for cov, name in zip(covs, cov_names):
776-
pm.Deterministic(name, cov, dims=FILTER_OUTPUT_DIMS.get(name, None))
779+
for var, name in zip(states + covs, state_names + cov_names):
780+
dim_names = FILTER_OUTPUT_DIMS.get(name, None)
781+
dims = tuple([dim if dim in coords.keys() else None for dim in dim_names])
782+
pm.Deterministic(name, var, dims=dims)
777783

778784
def build_statespace_graph(
779785
self,
@@ -866,7 +872,7 @@ def build_statespace_graph(
866872
all_kf_outputs = states + [smooth_states] + covs + [smooth_covariances]
867873
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
868874

869-
obs_dims = FILTER_OUTPUT_DIMS["obs"]
875+
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
870876
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
871877

872878
SequenceMvNormal(

pymc_experimental/statespace/utils/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
]
4545

4646
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
47-
OBSERVED_OUTPUT_NAMES = ["observed_state", "observed_covariance"]
47+
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
4848

4949
MATRIX_DIMS = {
5050
"x0": (ALL_STATE_DIM,),
@@ -65,7 +65,8 @@
6565
"filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
6666
"smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
6767
"predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
68-
"obs": (TIME_DIM, OBS_STATE_DIM),
68+
"predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
69+
"predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
6970
}
7071

7172
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]

pymc_experimental/tests/statespace/test_coord_assignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_filter_output_coord_assignment(f, warning, create_model):
9494
with warning:
9595
pymc_model = create_model(f)
9696

97-
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["obs"]:
97+
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]:
9898
assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output]
9999

100100

0 commit comments

Comments
 (0)