Skip to content

Commit 3249729

Browse files
Add JAX test suite
1 parent c3b05c9 commit 3249729

File tree

3 files changed

+194
-6
lines changed

3 files changed

+194
-6
lines changed

pymc_experimental/tests/statespace/test_kalman_filter.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
single_inout = initialize_filter(SingleTimeseriesFilter())
3838
steadystate_inout = initialize_filter(SteadyStateFilter())
3939

40-
f_standard = pytensor.function(*standard_inout)
41-
f_cholesky = pytensor.function(*cholesky_inout)
42-
f_univariate = pytensor.function(*univariate_inout)
43-
f_single_ts = pytensor.function(*single_inout)
44-
f_steady = pytensor.function(*steadystate_inout)
40+
f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
41+
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
42+
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
43+
f_single_ts = pytensor.function(*single_inout, on_unused_input="ignore")
44+
f_steady = pytensor.function(*steadystate_inout, on_unused_input="ignore")
4545

4646
filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts, f_steady]
4747

@@ -315,3 +315,37 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
315315
atol=ATOL,
316316
err_msg=f"{name} is not symmetrical",
317317
)
318+
319+
320+
@pytest.mark.parametrize(
321+
"filter",
322+
[StandardFilter, SingleTimeseriesFilter, CholeskyFilter],
323+
ids=["standard", "single_ts", "cholesky"],
324+
)
325+
def test_kalman_filter_jax(filter):
326+
pytest.importorskip("jax")
327+
from pymc.sampling.jax import get_jaxified_graph
328+
329+
# TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined
330+
331+
p, m, r, n = 1, 5, 1, 10
332+
inputs, outputs = initialize_filter(filter(), mode="JAX")
333+
334+
# Shape of the data must be static for jax to know how long the scan is
335+
data = inputs.pop(0)
336+
data_specified = pt.specify_shape(data, (n, None))
337+
data_specified.name = "data"
338+
inputs = [data] + inputs
339+
340+
outputs = pytensor.graph.clone_replace(outputs, {data: data_specified})
341+
342+
inputs_np = make_test_inputs(p, m, r, n, rng)
343+
344+
f_jax = get_jaxified_graph(inputs, outputs)
345+
f_pt = pytensor.function(inputs, outputs, mode="FAST_COMPILE")
346+
347+
jax_outputs = f_jax(*inputs_np)
348+
pt_outputs = f_pt(*inputs_np)
349+
350+
for name, jax_res, pt_res in zip(output_names, jax_outputs, pt_outputs):
351+
assert_allclose(jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!")

pymc_experimental/tests/statespace/test_statespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def pymc_mod(ss_mod):
9494
rho = pm.Beta("rho", 1, 1)
9595
zeta = pm.Deterministic("zeta", 1 - rho)
9696

97-
ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True)
97+
ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=False)
9898
names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
9999
for name, matrix in zip(names, ss_mod.unpack_statespace()):
100100
pm.Deterministic(name, matrix)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import warnings
2+
3+
import numpy as np
4+
import pymc as pm
5+
import pytensor
6+
import pytensor.tensor as pt
7+
import pytest
8+
from pymc.model.transform.optimization import freeze_dims_and_data
9+
10+
from pymc_experimental.statespace.utils.constants import (
11+
FILTER_OUTPUT_NAMES,
12+
MATRIX_NAMES,
13+
SMOOTHER_OUTPUT_NAMES,
14+
)
15+
from pymc_experimental.tests.statespace.test_statespace import ( # pylint: disable=unused-import
16+
exog_ss_mod,
17+
make_statespace_mod,
18+
ss_mod,
19+
)
20+
from pymc_experimental.tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
21+
rng,
22+
)
23+
from pymc_experimental.tests.statespace.utilities.test_helpers import (
24+
load_nile_test_data,
25+
)
26+
27+
pytest.importorskip("jax")
28+
29+
30+
floatX = pytensor.config.floatX
31+
nile = load_nile_test_data()
32+
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
33+
34+
35+
@pytest.fixture(scope="session")
36+
def pymc_mod(ss_mod):
37+
with pm.Model(coords=ss_mod.coords) as pymc_mod:
38+
rho = pm.Beta("rho", 1, 1)
39+
zeta = pm.Deterministic("zeta", 1 - rho)
40+
41+
ss_mod.build_statespace_graph(
42+
data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True
43+
)
44+
names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
45+
for name, matrix in zip(names, ss_mod.unpack_statespace()):
46+
pm.Deterministic(name, matrix)
47+
48+
return pymc_mod
49+
50+
51+
@pytest.fixture(scope="session")
52+
def exog_pymc_mod(exog_ss_mod, rng):
53+
y = rng.normal(size=(100, 1)).astype(floatX)
54+
X = rng.normal(size=(100, 3)).astype(floatX)
55+
56+
with pm.Model(coords=exog_ss_mod.coords) as m:
57+
exog_data = pm.Data("data_exog", X)
58+
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
59+
P0_sigma = pm.Exponential("P0_sigma", 1)
60+
P0 = pm.Deterministic(
61+
"P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
62+
)
63+
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
64+
65+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
66+
exog_ss_mod.build_statespace_graph(y, mode="JAX")
67+
68+
return m
69+
70+
71+
@pytest.fixture(scope="session")
72+
def idata(pymc_mod, rng):
73+
with warnings.catch_warnings(action="ignore"):
74+
with pymc_mod:
75+
idata = pm.sample(
76+
draws=10,
77+
tune=1,
78+
chains=1,
79+
random_seed=rng,
80+
nuts_sampler="numpyro",
81+
progressbar=False,
82+
)
83+
with freeze_dims_and_data(pymc_mod):
84+
idata_prior = pm.sample_prior_predictive(
85+
samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
86+
)
87+
88+
idata.extend(idata_prior)
89+
return idata
90+
91+
92+
@pytest.fixture(scope="session")
93+
def idata_exog(exog_pymc_mod, rng):
94+
with warnings.catch_warnings(action="ignore"):
95+
with exog_pymc_mod:
96+
idata = pm.sample(
97+
draws=10,
98+
tune=1,
99+
chains=1,
100+
random_seed=rng,
101+
nuts_sampler="numpyro",
102+
progressbar=False,
103+
)
104+
with freeze_dims_and_data(pymc_mod):
105+
idata_prior = pm.sample_prior_predictive(
106+
samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
107+
)
108+
109+
idata.extend(idata_prior)
110+
return idata
111+
112+
113+
@pytest.mark.parametrize("group", ["posterior", "prior"])
114+
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
115+
def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
116+
assert not np.any(np.isnan(idata[group][matrix].values))
117+
118+
119+
@pytest.mark.parametrize("group", ["prior", "posterior"])
120+
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
121+
def test_sampling_methods(group, kind, ss_mod, idata, rng):
122+
assert ss_mod._fit_mode == "JAX"
123+
124+
f = getattr(ss_mod, f"sample_{kind}_{group}")
125+
with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
126+
test_idata = f(idata, random_seed=rng)
127+
128+
if kind == "conditional":
129+
for output in ["filtered", "predicted", "smoothed"]:
130+
assert f"{output}_{group}" in test_idata
131+
assert not np.any(np.isnan(test_idata[f"{output}_{group}"].values))
132+
assert not np.any(np.isnan(test_idata[f"{output}_{group}_observed"].values))
133+
134+
if kind == "unconditional":
135+
for output in ["latent", "observed"]:
136+
assert f"{group}_{output}" in test_idata
137+
assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
138+
139+
140+
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
141+
def test_forecast(filter_output, ss_mod, idata, rng):
142+
time_idx = idata.posterior.coords["time"].values
143+
144+
with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
145+
forecast_idata = ss_mod.forecast(
146+
idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
147+
)
148+
149+
assert forecast_idata.coords["time"].values.shape == (10,)
150+
assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
151+
assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")
152+
153+
assert not np.any(np.isnan(forecast_idata.forecast_latent.values))
154+
assert not np.any(np.isnan(forecast_idata.forecast_observed.values))

0 commit comments

Comments
 (0)