|
| 1 | +import warnings |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pymc as pm |
| 5 | +import pytensor |
| 6 | +import pytensor.tensor as pt |
| 7 | +import pytest |
| 8 | +from pymc.model.transform.optimization import freeze_dims_and_data |
| 9 | + |
| 10 | +from pymc_experimental.statespace.utils.constants import ( |
| 11 | + FILTER_OUTPUT_NAMES, |
| 12 | + MATRIX_NAMES, |
| 13 | + SMOOTHER_OUTPUT_NAMES, |
| 14 | +) |
| 15 | +from pymc_experimental.tests.statespace.test_statespace import ( # pylint: disable=unused-import |
| 16 | + exog_ss_mod, |
| 17 | + make_statespace_mod, |
| 18 | + ss_mod, |
| 19 | +) |
| 20 | +from pymc_experimental.tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import |
| 21 | + rng, |
| 22 | +) |
| 23 | +from pymc_experimental.tests.statespace.utilities.test_helpers import ( |
| 24 | + load_nile_test_data, |
| 25 | +) |
| 26 | + |
| 27 | +pytest.importorskip("jax") |
| 28 | + |
| 29 | + |
| 30 | +floatX = pytensor.config.floatX |
| 31 | +nile = load_nile_test_data() |
| 32 | +ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES |
| 33 | + |
| 34 | + |
| 35 | +@pytest.fixture(scope="session") |
| 36 | +def pymc_mod(ss_mod): |
| 37 | + with pm.Model(coords=ss_mod.coords) as pymc_mod: |
| 38 | + rho = pm.Beta("rho", 1, 1) |
| 39 | + zeta = pm.Deterministic("zeta", 1 - rho) |
| 40 | + |
| 41 | + ss_mod.build_statespace_graph( |
| 42 | + data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True |
| 43 | + ) |
| 44 | + names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] |
| 45 | + for name, matrix in zip(names, ss_mod.unpack_statespace()): |
| 46 | + pm.Deterministic(name, matrix) |
| 47 | + |
| 48 | + return pymc_mod |
| 49 | + |
| 50 | + |
| 51 | +@pytest.fixture(scope="session") |
| 52 | +def exog_pymc_mod(exog_ss_mod, rng): |
| 53 | + y = rng.normal(size=(100, 1)).astype(floatX) |
| 54 | + X = rng.normal(size=(100, 3)).astype(floatX) |
| 55 | + |
| 56 | + with pm.Model(coords=exog_ss_mod.coords) as m: |
| 57 | + exog_data = pm.Data("data_exog", X) |
| 58 | + initial_trend = pm.Normal("initial_trend", dims=["trend_state"]) |
| 59 | + P0_sigma = pm.Exponential("P0_sigma", 1) |
| 60 | + P0 = pm.Deterministic( |
| 61 | + "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"] |
| 62 | + ) |
| 63 | + beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) |
| 64 | + |
| 65 | + sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) |
| 66 | + exog_ss_mod.build_statespace_graph(y, mode="JAX") |
| 67 | + |
| 68 | + return m |
| 69 | + |
| 70 | + |
| 71 | +@pytest.fixture(scope="session") |
| 72 | +def idata(pymc_mod, rng): |
| 73 | + with warnings.catch_warnings(action="ignore"): |
| 74 | + with pymc_mod: |
| 75 | + idata = pm.sample( |
| 76 | + draws=10, |
| 77 | + tune=1, |
| 78 | + chains=1, |
| 79 | + random_seed=rng, |
| 80 | + nuts_sampler="numpyro", |
| 81 | + progressbar=False, |
| 82 | + ) |
| 83 | + with freeze_dims_and_data(pymc_mod): |
| 84 | + idata_prior = pm.sample_prior_predictive( |
| 85 | + samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"} |
| 86 | + ) |
| 87 | + |
| 88 | + idata.extend(idata_prior) |
| 89 | + return idata |
| 90 | + |
| 91 | + |
| 92 | +@pytest.fixture(scope="session") |
| 93 | +def idata_exog(exog_pymc_mod, rng): |
| 94 | + with warnings.catch_warnings(action="ignore"): |
| 95 | + with exog_pymc_mod: |
| 96 | + idata = pm.sample( |
| 97 | + draws=10, |
| 98 | + tune=1, |
| 99 | + chains=1, |
| 100 | + random_seed=rng, |
| 101 | + nuts_sampler="numpyro", |
| 102 | + progressbar=False, |
| 103 | + ) |
| 104 | + with freeze_dims_and_data(pymc_mod): |
| 105 | + idata_prior = pm.sample_prior_predictive( |
| 106 | + samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"} |
| 107 | + ) |
| 108 | + |
| 109 | + idata.extend(idata_prior) |
| 110 | + return idata |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.parametrize("group", ["posterior", "prior"]) |
| 114 | +@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS) |
| 115 | +def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata): |
| 116 | + assert not np.any(np.isnan(idata[group][matrix].values)) |
| 117 | + |
| 118 | + |
| 119 | +@pytest.mark.parametrize("group", ["prior", "posterior"]) |
| 120 | +@pytest.mark.parametrize("kind", ["conditional", "unconditional"]) |
| 121 | +def test_sampling_methods(group, kind, ss_mod, idata, rng): |
| 122 | + assert ss_mod._fit_mode == "JAX" |
| 123 | + |
| 124 | + f = getattr(ss_mod, f"sample_{kind}_{group}") |
| 125 | + with pytest.warns(UserWarning, match="The RandomType SharedVariables"): |
| 126 | + test_idata = f(idata, random_seed=rng) |
| 127 | + |
| 128 | + if kind == "conditional": |
| 129 | + for output in ["filtered", "predicted", "smoothed"]: |
| 130 | + assert f"{output}_{group}" in test_idata |
| 131 | + assert not np.any(np.isnan(test_idata[f"{output}_{group}"].values)) |
| 132 | + assert not np.any(np.isnan(test_idata[f"{output}_{group}_observed"].values)) |
| 133 | + |
| 134 | + if kind == "unconditional": |
| 135 | + for output in ["latent", "observed"]: |
| 136 | + assert f"{group}_{output}" in test_idata |
| 137 | + assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values)) |
| 138 | + |
| 139 | + |
| 140 | +@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"]) |
| 141 | +def test_forecast(filter_output, ss_mod, idata, rng): |
| 142 | + time_idx = idata.posterior.coords["time"].values |
| 143 | + |
| 144 | + with pytest.warns(UserWarning, match="The RandomType SharedVariables"): |
| 145 | + forecast_idata = ss_mod.forecast( |
| 146 | + idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng |
| 147 | + ) |
| 148 | + |
| 149 | + assert forecast_idata.coords["time"].values.shape == (10,) |
| 150 | + assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state") |
| 151 | + assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state") |
| 152 | + |
| 153 | + assert not np.any(np.isnan(forecast_idata.forecast_latent.values)) |
| 154 | + assert not np.any(np.isnan(forecast_idata.forecast_observed.values)) |
0 commit comments