diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 40aaec124..4290c850e 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -137,7 +137,6 @@ def make_model(index): with pytest.warns(UserWarning, match="No time index found on the supplied data"): ss_mod.build_statespace_graph( a["A"], - mode="JAX", ) return model diff --git a/tests/statespace/test_statespace_JAX.py b/tests/statespace/test_statespace_JAX.py index 9e8d9975f..b045f82ea 100644 --- a/tests/statespace/test_statespace_JAX.py +++ b/tests/statespace/test_statespace_JAX.py @@ -22,8 +22,8 @@ ) from tests.statespace.utilities.test_helpers import load_nile_test_data -pytest.importorskip("jax") -pytest.importorskip("numpyro") +pytest.importorskip("numba") +# pytest.importorskip("numpyro") floatX = pytensor.config.floatX @@ -38,7 +38,7 @@ def pymc_mod(ss_mod): zeta = pm.Deterministic("zeta", 1 - rho) ss_mod.build_statespace_graph( - data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True + data=nile, mode="NUMBA", save_kalman_filter_outputs_in_idata=True ) names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] for name, matrix in zip(names, ss_mod.unpack_statespace()): @@ -62,7 +62,7 @@ def exog_pymc_mod(exog_ss_mod, rng): beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) - exog_ss_mod.build_statespace_graph(y, mode="JAX") + exog_ss_mod.build_statespace_graph(y, mode="NUMBA") return m @@ -77,12 +77,13 @@ def idata(pymc_mod, rng): tune=1, chains=1, random_seed=rng, - nuts_sampler="numpyro", + nuts_sampler="pymc", + compile_kwargs={"mode": "NUMBA"}, progressbar=False, ) with freeze_dims_and_data(pymc_mod): idata_prior = pm.sample_prior_predictive( - samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"} + samples=10, random_seed=rng, compile_kwargs={"mode": "NUMBA"} ) idata.extend(idata_prior) @@ -100,12 +101,13 @@ def idata_exog(exog_pymc_mod, rng): tune=1, chains=1, random_seed=rng, - nuts_sampler="numpyro", + nuts_sampler="pymc", + compile_kwargs={"mode": "NUMBA"}, progressbar=False, ) with freeze_dims_and_data(pymc_mod): idata_prior = pm.sample_prior_predictive( - samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"} + samples=10, random_seed=rng, compile_kwargs={"mode": "NUMBA"} ) idata.extend(idata_prior) @@ -121,7 +123,7 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata): @pytest.mark.parametrize("group", ["prior", "posterior"]) @pytest.mark.parametrize("kind", ["conditional", "unconditional"]) def test_sampling_methods(group, kind, ss_mod, idata, rng): - assert ss_mod._fit_mode == "JAX" + assert ss_mod._fit_mode == "NUMBA" f = getattr(ss_mod, f"sample_{kind}_{group}") with pytest.warns(UserWarning, match="The RandomType SharedVariables"):