From 6a1862a06a015ec7ad61c58bd88df737550b08fa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 20 Mar 2024 14:42:16 +0100 Subject: [PATCH 1/4] Bump PyTensor dependency --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-jax.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/sampling/forward.py | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 5065a4c8b0..a2cf7c25d3 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -14,7 +14,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 809e7ee5cf..8622703837 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - scipy>=1.4.1 - typing-extensions>=3.7.4 diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 0419863db7..3defc540bd 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -21,7 +21,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 594e1ca79b..8272cca239 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -17,7 +17,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index bc0bc607bf..25fdeb419c 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -14,7 +14,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 6dd348bc79..900e3e227e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -17,7 +17,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.18.1,<2.19 +- pytensor>=2.19,<2.20 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 6dba188d6b..3a1fc5a785 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -200,7 +200,7 @@ def shared_value_matches(var): # Walk the graph from inputs to outputs and tag the volatile variables nodes: list[Variable] = general_toposort( fg.outputs, deps=lambda x: x.owner.inputs if x.owner else [] - ) + ) # type: ignore volatile_nodes: set[Any] = set() for node in nodes: if ( diff --git a/requirements-dev.txt b/requirements-dev.txt index 8aff7d60c9..ddcf9ded9b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,7 +18,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 -pytensor>=2.18.1,<2.19 +pytensor>=2.19,<2.20 pytest-cov>=2.5 pytest>=3.0 scipy>=1.4.1 diff --git a/requirements.txt b/requirements.txt index c84312012f..0bc21049c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ cloudpickle fastprogress>=0.2.0 numpy>=1.15.0 pandas>=0.24.0 -pytensor>=2.18.1,<2.19 +pytensor>=2.19,<2.20 scipy>=1.4.1 typing-extensions>=3.7.4 From 5e89b694324f5a3656e11ceea6a14073369a8eaa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 20 Mar 2024 15:15:02 +0100 Subject: [PATCH 2/4] Test on Python 3.12 --- .github/workflows/tests.yml | 8 ++++---- conda-envs/environment-jax.yml | 4 ++-- setup.py | 1 + 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b693fe6e33..71b78a9747 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: matrix: os: [ubuntu-20.04] floatx: [float64] - python-version: ["3.11"] + python-version: ["3.12"] test-subset: - | tests/test_util.py @@ -263,7 +263,7 @@ jobs: matrix: os: [macos-latest] floatx: [float64] - python-version: ["3.10"] + python-version: ["3.12"] test-subset: - | tests/sampling/test_parallel.py @@ -342,7 +342,7 @@ jobs: matrix: os: [ubuntu-20.04] floatx: [float64] - python-version: ["3.11"] + python-version: ["3.12"] test-subset: - tests/sampling/test_jax.py tests/sampling/test_mcmc_external.py fail-fast: false @@ -410,7 +410,7 @@ jobs: matrix: os: [windows-latest] floatx: [float32] - python-version: ["3.11"] + python-version: ["3.12"] test-subset: - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py fail-fast: false diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 3defc540bd..0986f43046 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -13,8 +13,8 @@ dependencies: - h5py>=2.7 # Jaxlib version must not be greater than jax version! - blackjax>=1.0.0 -- jaxlib==0.4.14 -- jax==0.4.16 +- jaxlib==0.4.23 +- jax==0.4.23 - libblas=*=*mkl - mkl-service - numpy>=1.15.0 diff --git a/setup.py b/setup.py index a0ac9dd301..2b5e03dd2c 100755 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", From 64e99febb0a59a0462047ccf911bd9abe5e5eb6a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 20 Mar 2024 19:29:00 +0100 Subject: [PATCH 3/4] Ignore DeprecationWarning triggered by arviz --- tests/backends/test_arviz.py | 12 ++++++++++-- tests/distributions/test_timeseries.py | 7 ++++++- tests/sampling/test_mcmc_external.py | 2 +- tests/variational/test_inference.py | 22 +++++++++++++++++----- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 9f2c74dd5c..cf9bc3fc00 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -32,7 +32,11 @@ from pymc.exceptions import ImputationWarning # Turn all warnings into errors for this module -pytestmark = pytest.mark.filterwarnings("error") +pytestmark = pytest.mark.filterwarnings( + "error", + # Related to https://github.com/arviz-devs/arviz/issues/2327 + "ignore:datetime.datetime.utcnow():DeprecationWarning", +) @pytest.fixture(scope="module") @@ -672,13 +676,17 @@ def test_include_transformed(self): ) assert "p_interval__" in inference_data.posterior + @pytest.mark.filterwarnings( + "error", + # Related to https://github.com/arviz-devs/arviz/issues/2327 + "ignore:datetime.datetime.utcnow():DeprecationWarning", + ) @pytest.mark.parametrize("chains", (1, 2)) def test_single_chain(self, chains): # Test that no UserWarning is raised when sampling with NUTS defaults # When this test was added, a `UserWarning: More chains (500) than draws (1)` used to be issued # when sampling with a single chain - warnings.simplefilter("error") with pm.Model(): pm.Normal("x") pm.sample(chains=chains, return_inferencedata=True) diff --git a/tests/distributions/test_timeseries.py b/tests/distributions/test_timeseries.py index 208ffffaeb..34b3f1c4a0 100644 --- a/tests/distributions/test_timeseries.py +++ b/tests/distributions/test_timeseries.py @@ -48,7 +48,12 @@ # Turn all warnings into errors for this module # Ignoring NumPy deprecation warning tracked in https://github.com/pymc-devs/pytensor/issues/146 -pytestmark = pytest.mark.filterwarnings("error", "ignore: NumPy will stop allowing conversion") +pytestmark = pytest.mark.filterwarnings( + "error", + "ignore: NumPy will stop allowing conversion", + # Related to https://github.com/arviz-devs/arviz/issues/2327 + "ignore:datetime.datetime.utcnow():DeprecationWarning", +) class TestRandomWalk: diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 4975ee6e7d..a6c0e16680 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -47,7 +47,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): warns = { (warn.category, warn.message.args[0]) for warn in recwarn - if warn.category is not FutureWarning + if warn.category not in (FutureWarning, DeprecationWarning) } expected = set() if nuts_sampler == "nutpie": diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index 2e0a3c1887..99613511a7 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -172,7 +172,13 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data): warn_ctxt = nullcontext() with warn_ctxt: - trace = inference.fit(**fit_kwargs).sample(10000) + with warnings.catch_warnings(): + # Related to https://github.com/arviz-devs/arviz/issues/2327 + warnings.filterwarnings( + "ignore", message="datetime.datetime.utcnow()", category=DeprecationWarning + ) + + trace = inference.fit(**fit_kwargs).sample(10000) mu_post = simple_model_data["mu_post"] d = simple_model_data["d"] np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05) @@ -207,10 +213,16 @@ def test_fit_start(inference_spec, simple_model): expected_warning = observed_value.name.startswith("minibatch") with warnings.catch_warnings(record=True) as record: warnings.simplefilter("always") - try: - trace = inference.fit(n=0).sample(10000) - except NotImplementedInference as e: - pytest.skip(str(e)) + with warnings.catch_warnings(): + # Related to https://github.com/arviz-devs/arviz/issues/2327 + warnings.filterwarnings( + "ignore", message="datetime.datetime.utcnow()", category=DeprecationWarning + ) + + try: + trace = inference.fit(n=0).sample(10000) + except NotImplementedInference as e: + pytest.skip(str(e)) if expected_warning: assert len(record) > 0 From fc98122b188083082d7471fb7ce49d925fdad03a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Mar 2024 11:18:40 +0100 Subject: [PATCH 4/4] Remove filter for solved warning --- tests/distributions/test_timeseries.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/distributions/test_timeseries.py b/tests/distributions/test_timeseries.py index 34b3f1c4a0..4e1bfd723b 100644 --- a/tests/distributions/test_timeseries.py +++ b/tests/distributions/test_timeseries.py @@ -47,10 +47,8 @@ from pymc.testing import assert_support_point_is_expected, select_by_precision # Turn all warnings into errors for this module -# Ignoring NumPy deprecation warning tracked in https://github.com/pymc-devs/pytensor/issues/146 pytestmark = pytest.mark.filterwarnings( "error", - "ignore: NumPy will stop allowing conversion", # Related to https://github.com/arviz-devs/arviz/issues/2327 "ignore:datetime.datetime.utcnow():DeprecationWarning", )