Skip to content

Commit c3b05c9

Browse files
Adjust test suite to reflect API changes
Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests
1 parent 65b113e commit c3b05c9

File tree

3 files changed

+44
-47
lines changed

3 files changed

+44
-47
lines changed

pymc_experimental/tests/statespace/test_statespace.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
126126
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
127127

128128
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
129-
exog_ss_mod.build_statespace_graph(y)
129+
exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)
130130

131131
return m
132132

@@ -299,34 +299,3 @@ def test_forecast_fails_if_exog_needed(exog_ss_mod, idata_exog):
299299
forecast_idata = exog_ss_mod.forecast(
300300
idata_exog, start=time_idx[-1], periods=10, random_seed=rng
301301
)
302-
303-
304-
@pytest.mark.parametrize(
305-
"shape",
306-
[
307-
None,
308-
(10,),
309-
(10, 1),
310-
pytest.param((10, 3), marks=[pytest.mark.xfail]),
311-
pytest.param((10, 1, 1), marks=[pytest.mark.xfail]),
312-
],
313-
ids=["None", "(10,)", "(10, 1)", "(10,3)", "(10, 1, 1)"],
314-
)
315-
def test_add_exogenous(rng, shape):
316-
ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
317-
y = rng.normal(size=(10, 1))
318-
319-
with pm.Model() as m:
320-
initial_trend = pm.Normal("initial_trend", shape=(1,))
321-
P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX))
322-
constant = pm.Normal("constant", shape=shape)
323-
ss_mod.add_exogenous(constant)
324-
ss_mod.build_statespace_graph(data=y)
325-
326-
names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
327-
for name, matrix in zip(names, ss_mod.unpack_statespace()):
328-
if name not in [x.name for x in m.deterministics]:
329-
pm.Deterministic(name, matrix)
330-
331-
d, const = pm.draw([m["d"].squeeze(), m["constant"].squeeze()], 10)
332-
assert_allclose(d, const)

pymc_experimental/tests/statespace/test_structural.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,45 @@ def _assert_coord_shapes_match_matrices(mod, params):
6464
params["initial_state_cov"] = np.eye(mod.k_states)
6565

6666
x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, params)
67+
6768
n_states = len(mod.coords[ALL_STATE_DIM])
68-
n_shocks = len(mod.coords[SHOCK_DIM])
69+
70+
# There will always be one shock dimension -- dummies are inserted into fully deterministic models to avoid errors
71+
# in the state space representation.
72+
n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
6973
n_obs = len(mod.coords[OBS_STATE_DIM])
7074

71-
assert x0.shape[-1:] == (n_states,)
72-
assert P0.shape[-2:] == (n_states, n_states)
73-
assert c.shape[-1:] == (n_states,)
74-
assert d.shape[-1:] == (n_obs,)
75-
assert T.shape[-2:] == (n_states, n_states)
76-
assert Z.shape[-2:] == (n_obs, n_states)
77-
assert R.shape[-2:] == (n_states, n_shocks)
78-
assert H.shape[-2:] == (n_obs, n_obs)
79-
assert Q.shape[-2:] == (n_shocks, n_shocks)
75+
assert x0.shape[-1:] == (
76+
n_states,
77+
), f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
78+
assert P0.shape[-2:] == (
79+
n_states,
80+
n_states,
81+
), f"P0 expected to have shape (n_states, n_states), found {P0.shape[-2:]}"
82+
assert c.shape[-1:] == (
83+
n_states,
84+
), f"c expected to have shape (n_states, ), found {c.shape[-1:]}"
85+
assert d.shape[-1:] == (n_obs,), f"d expected to have shape (n_obs, ), found {d.shape[-1:]}"
86+
assert T.shape[-2:] == (
87+
n_states,
88+
n_states,
89+
), f"T expected to have shape (n_states, n_states), found {T.shape[-2:]}"
90+
assert Z.shape[-2:] == (
91+
n_obs,
92+
n_states,
93+
), f"Z expected to have shape (n_obs, n_states), found {Z.shape[-2:]}"
94+
assert R.shape[-2:] == (
95+
n_states,
96+
n_shocks,
97+
), f"R expected to have shape (n_states, n_shocks), found {R.shape[-2:]}"
98+
assert H.shape[-2:] == (
99+
n_obs,
100+
n_obs,
101+
), f"H expected to have shape (n_obs, n_obs), found {H.shape[-2:]}"
102+
assert Q.shape[-2:] == (
103+
n_shocks,
104+
n_shocks,
105+
), f"Q expected to have shape (n_shocks, n_shocks), found {Q.shape[-2:]}"
80106

81107

82108
def _assert_basic_coords_correct(mod):
@@ -514,12 +540,12 @@ def test_structural_model_against_statsmodels(
514540

515541
_assert_all_statespace_matrices_match(mod, params, sm_mod)
516542

517-
mod.build(verbose=False)
543+
built_model = mod.build(verbose=False)
518544

519-
_assert_coord_shapes_match_matrices(mod, params)
520-
_assert_param_dims_correct(mod.param_dims, expected_dims)
521-
_assert_coords_correct(mod.coords, expected_coords)
522-
_assert_params_info_correct(mod.param_info, mod.coords, mod.param_dims)
545+
_assert_coord_shapes_match_matrices(built_model, params)
546+
_assert_param_dims_correct(built_model.param_dims, expected_dims)
547+
_assert_coords_correct(built_model.coords, expected_coords)
548+
_assert_params_info_correct(built_model.param_info, built_model.coords, built_model.param_dims)
523549

524550

525551
def test_level_trend_model(rng):

pymc_experimental/tests/statespace/utilities/test_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="
225225
on_unused_input="raise",
226226
mode=mode,
227227
)
228+
228229
x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict, **data_dict)
230+
229231
return x0, P0, c, d, T, Z, R, H, Q
230232

231233

0 commit comments

Comments
 (0)