diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index a91fe2ea..d50875dc 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -10,6 +10,6 @@ dependencies: - dask - xhistogram - pip: - - pymc>=5.4.1 # CI was failing to resolve + - pymc>=5.6.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index ca1a5f4a..2f4040bc 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -9,5 +9,5 @@ dependencies: - dask - xhistogram - pip: - - pymc>=5.4.1 # CI was failing to resolve + - pymc>=5.6.0 # CI was failing to resolve - scikit-learn diff --git a/pymc_experimental/tests/test_marginal_model.py b/pymc_experimental/tests/test_marginal_model.py index cab0404c..d90bc5c2 100644 --- a/pymc_experimental/tests/test_marginal_model.py +++ b/pymc_experimental/tests/test_marginal_model.py @@ -207,14 +207,14 @@ def test_marginalized_change_point_model(disaster_model): ip = m.initial_point() ip.pop("switchpoint") ref_logp_fn = m.compile_logp( - [m["switchpoint"], m["disasters_observed"], m["disasters_missing"]] + [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]] ) ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years]) with pytest.warns(UserWarning, match="There are multiple dependent variables"): m.marginalize(m["switchpoint"]) - logp = m.compile_logp([m["disasters_observed"], m["disasters_missing"]])(ip) + logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip) np.testing.assert_almost_equal(logp, ref_logp) @@ -241,7 +241,9 @@ def test_marginalized_change_point_model_sampling(disaster_model): before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2 ) np.testing.assert_allclose( - before_marg["disasters_missing"].mean(), after_marg["disasters_missing"].mean(), rtol=1e-2 + before_marg["disasters_unobserved"].mean(), + after_marg["disasters_unobserved"].mean(), + rtol=1e-2, ) diff --git a/requirements.txt b/requirements.txt index 4b306c2f..088372c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.5.0 +pymc>=5.6.0 scikit-learn