From 30dedf24b14228565131e79e8a27cf9de643c31c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 13:07:51 -0400 Subject: [PATCH 01/14] push up pymc-marketing mock --- pymc/testing.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/pymc/testing.py b/pymc/testing.py index 7ef6751892..2a303dd51a 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -18,6 +18,8 @@ from collections.abc import Callable, Sequence from typing import Any +from arviz import InferenceData +from xarray import DataArray import numpy as np import pytensor import pytensor.tensor as pt @@ -982,3 +984,53 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: rvs = rvs_in_graph(vars) if rvs: raise AssertionError(f"RV found in graph: {rvs}") + + +def mock_sample(*args, **kwargs): + """Mock the pm.sample function by returning prior predictive samples as posterior. + + Useful for testing models that use pm.sample without running MCMC sampling. + + Examples + -------- + Using mock_sample with pytest + + .. code-block:: python + + import pytest + + import pymc as pm + from pymc.testing import mock_sample + + + @pytest.fixture(scope="module") + def mock_pymc_sample(): + original_sample = pm.sample + pm.sample = mock_sample + + yield + + pm.sample = original_sample + + """ + random_seed = kwargs.get("random_seed", None) + model = kwargs.get("model", None) + draws = kwargs.get("draws", 10) + n_chains = kwargs.get("chains", 1) + idata: InferenceData = pm.sample_prior_predictive( + model=model, + random_seed=random_seed, + draws=draws, + ) + + expanded_chains = DataArray( + np.ones(n_chains), + coords={"chain": np.arange(n_chains)}, + ) + idata.add_groups( + posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...) + ) + del idata.prior + if "prior_predictive" in idata: + del idata.prior_predictive + return idata From 54705ae3d791f8edf12a345ce694eb2265888650 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 13:37:31 -0400 Subject: [PATCH 02/14] run pre-commit --- pymc/testing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 2a303dd51a..08f38a2fdf 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -18,12 +18,11 @@ from collections.abc import Callable, Sequence from typing import Any -from arviz import InferenceData -from xarray import DataArray import numpy as np import pytensor import pytensor.tensor as pt +from arviz import InferenceData from numpy import random as nr from numpy import testing as npt from pytensor.compile.mode import Mode @@ -33,6 +32,7 @@ from pytensor.tensor.random.op import RandomVariable from scipy import special as sp from scipy import stats as st +from xarray import DataArray import pymc as pm @@ -1028,7 +1028,7 @@ def mock_pymc_sample(): coords={"chain": np.arange(n_chains)}, ) idata.add_groups( - posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...) + posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...), ) del idata.prior if "prior_predictive" in idata: From 6fd8f453a8c0a40994e12ea4a883926a71672d85 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 13:42:29 -0400 Subject: [PATCH 03/14] add small test --- tests/test_testing.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_testing.py b/tests/test_testing.py index c8caf063c2..e86500f937 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -15,7 +15,8 @@ import pytest -from pymc.testing import Domain +from pymc.testing import Domain, mock_sample +from tests.models import simple_normal @pytest.mark.parametrize( @@ -32,3 +33,15 @@ def test_domain(values, edges, expectation): with expectation: Domain(values, edges=edges) + + +def test_mock_sample() -> None: + _, model, _ = simple_normal(bounded_prior=True) + + idata = mock_sample(model=model) + + assert "posterior" in idata + assert "observed_data" in idata + assert "prior" not in idata + assert "posterior_predictive" not in idata + assert "sample_stats" not in idata From 36da40ab357196da502a6684ae7dd643b14ae92e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 13:56:06 -0400 Subject: [PATCH 04/14] use positional arg for draws like in actual sample --- pymc/testing.py | 4 ++-- tests/test_testing.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 08f38a2fdf..5840336ce6 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -986,7 +986,7 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: raise AssertionError(f"RV found in graph: {rvs}") -def mock_sample(*args, **kwargs): +def mock_sample(draws: int = 10, **kwargs): """Mock the pm.sample function by returning prior predictive samples as posterior. Useful for testing models that use pm.sample without running MCMC sampling. @@ -1015,7 +1015,7 @@ def mock_pymc_sample(): """ random_seed = kwargs.get("random_seed", None) model = kwargs.get("model", None) - draws = kwargs.get("draws", 10) + draws = kwargs.get("draws", draws) n_chains = kwargs.get("chains", 1) idata: InferenceData = pm.sample_prior_predictive( model=model, diff --git a/tests/test_testing.py b/tests/test_testing.py index e86500f937..ff05b5e869 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -35,13 +35,24 @@ def test_domain(values, edges, expectation): Domain(values, edges=edges) -def test_mock_sample() -> None: +@pytest.mark.parametrize( + "args, kwargs, expected_draws", + [ + pytest.param((), {}, 10, id="default"), + pytest.param((100,), {}, 100, id="positional-draws"), + pytest.param((), {"draws": 100}, 100, id="keyword-draws"), + ], +) +def test_mock_sample(args, kwargs, expected_draws) -> None: _, model, _ = simple_normal(bounded_prior=True) - idata = mock_sample(model=model) + with model: + idata = mock_sample(*args, **kwargs) assert "posterior" in idata assert "observed_data" in idata assert "prior" not in idata assert "posterior_predictive" not in idata assert "sample_stats" not in idata + + assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws} From 5884a39de74b2fca00ca03e584fde934af548a9c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 14:03:29 -0400 Subject: [PATCH 05/14] better for mypy --- pymc/testing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 5840336ce6..5b80bdddd0 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1028,9 +1028,9 @@ def mock_pymc_sample(): coords={"chain": np.arange(n_chains)}, ) idata.add_groups( - posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...), + posterior=(idata["prior"].mean("chain") * expanded_chains).transpose("chain", "draw", ...), ) - del idata.prior + del idata["prior"] if "prior_predictive" in idata: - del idata.prior_predictive + del idata["prior_predictive"] return idata From 34709f2c1a7680c67a057eb72823358c9cc74fa0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 14:30:44 -0400 Subject: [PATCH 06/14] provide the setup and breakdown for pytest fixtures --- pymc/testing.py | 51 +++++++++++++++++++++++++++++++++++++++++++ tests/test_testing.py | 17 ++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index 5b80bdddd0..88f47f677a 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1034,3 +1034,54 @@ def mock_pymc_sample(): if "prior_predictive" in idata: del idata["prior_predictive"] return idata + + +def mock_sample_setup_and_breakdown(): + """Set up and tear down mocking of PyMC sampling functions for testing. + + This function is designed to be used with pytest fixtures to temporarily replace + PyMC's sampling functionality with faster alternatives for testing purposes. + + Effects during the fixture's active period: + * Replaces pm.sample with mock_sample, which uses prior predictive sampling + instead of MCMC + * Replaces pm.Flat with pm.Normal to avoid issues with unbounded priors + * Replaces pm.HalfFlat with pm.HalfNormal to avoid issues with semi-bounded priors + * Automatically restores all original functions after the test completes + + Examples + -------- + .. code-block:: python + + import pytest + import pymc as pm + from pymc.testing import mock_sample_setup_and_breakdown + + # Register as a pytest fixture + mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown) + + + # Use in a test function + def test_model_inference(mock_pymc_sample): + with pm.Model() as model: + x = pm.Normal("x", 0, 1) + # This will use mock_sample instead of actual MCMC + idata = pm.sample() + # Test with the inference data... + + """ + import pymc as pm + + original_flat = pm.Flat + original_half_flat = pm.HalfFlat + original_sample = pm.sample + + pm.sample = mock_sample + pm.Flat = pm.Normal + pm.HalfFlat = pm.HalfNormal + + yield + + pm.sample = original_sample + pm.Flat = original_flat + pm.HalfFlat = original_half_flat diff --git a/tests/test_testing.py b/tests/test_testing.py index ff05b5e869..b3eff4d335 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -15,7 +15,9 @@ import pytest -from pymc.testing import Domain, mock_sample +import pymc as pm + +from pymc.testing import Domain, mock_sample, mock_sample_setup_and_breakdown from tests.models import simple_normal @@ -56,3 +58,16 @@ def test_mock_sample(args, kwargs, expected_draws) -> None: assert "sample_stats" not in idata assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws} + + +mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown) + + +def test_fixture(mock_pymc_sample) -> None: + # This has Flat distribution + _, model, _ = simple_normal(bounded_prior=False) + + with model: + idata = pm.sample() + + assert idata.posterior.sizes == {"chain": 1, "draw": 10} From d3d965a644bbcb13a6e3cbd798705ce0e769772c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Apr 2025 14:39:17 -0400 Subject: [PATCH 07/14] change name for testing convention --- pymc/testing.py | 6 +++--- tests/test_testing.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 88f47f677a..36d72c2664 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1036,7 +1036,7 @@ def mock_pymc_sample(): return idata -def mock_sample_setup_and_breakdown(): +def mock_sample_setup_and_teardown(): """Set up and tear down mocking of PyMC sampling functions for testing. This function is designed to be used with pytest fixtures to temporarily replace @@ -1055,10 +1055,10 @@ def mock_sample_setup_and_breakdown(): import pytest import pymc as pm - from pymc.testing import mock_sample_setup_and_breakdown + from pymc.testing import mock_sample_setup_and_teardown # Register as a pytest fixture - mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown) + mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) # Use in a test function diff --git a/tests/test_testing.py b/tests/test_testing.py index b3eff4d335..eb55af11a5 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -17,7 +17,7 @@ import pymc as pm -from pymc.testing import Domain, mock_sample, mock_sample_setup_and_breakdown +from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown from tests.models import simple_normal @@ -60,7 +60,7 @@ def test_mock_sample(args, kwargs, expected_draws) -> None: assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws} -mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown) +mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) def test_fixture(mock_pymc_sample) -> None: From 7312437e48bbf89d90adeab7412241b2cd1921d5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 23 Apr 2025 13:06:21 -0400 Subject: [PATCH 08/14] bit more explicit on the test --- tests/test_testing.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_testing.py b/tests/test_testing.py index eb55af11a5..78002d03a9 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -63,11 +63,19 @@ def test_mock_sample(args, kwargs, expected_draws) -> None: mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) -def test_fixture(mock_pymc_sample) -> None: - # This has Flat distribution - _, model, _ = simple_normal(bounded_prior=False) +@pytest.fixture(scope="function") +def dummy_model() -> pm.Model: + with pm.Model() as model: + pm.Flat("flat") + pm.HalfFlat("half_flat") - with model: + return model + + +def test_fixture(mock_pymc_sample, dummy_model) -> None: + with dummy_model: idata = pm.sample() - assert idata.posterior.sizes == {"chain": 1, "draw": 10} + posterior = idata.posterior + assert posterior.sizes == {"chain": 1, "draw": 10} + assert (posterior["half_flat"] >= 0).all() From ebe8cd20dd8f8876c2d40892dbaa6e6dd563d173 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:16:11 -0400 Subject: [PATCH 09/14] add to the documentation --- docs/source/api/testing.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 docs/source/api/testing.rst diff --git a/docs/source/api/testing.rst b/docs/source/api/testing.rst new file mode 100644 index 0000000000..10e01a5cf2 --- /dev/null +++ b/docs/source/api/testing.rst @@ -0,0 +1,14 @@ +======= +Testing +======= + +This submodule contains tools to help with testing PyMC code. + + +.. currentmodule:: pymc.testing + +.. autosummary:: + :toctree: generated/ + + mock_sample + mock_sample_setup_and_teardown From 06754f7320d6fccd03ca8526e5fa52c0e9c181cf Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:30:06 -0400 Subject: [PATCH 10/14] use expand_dims method --- pymc/testing.py | 12 ++++++------ tests/test_testing.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 36d72c2664..a7ec51152b 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -32,7 +32,6 @@ from pytensor.tensor.random.op import RandomVariable from scipy import special as sp from scipy import stats as st -from xarray import DataArray import pymc as pm @@ -1023,12 +1022,13 @@ def mock_pymc_sample(): draws=draws, ) - expanded_chains = DataArray( - np.ones(n_chains), - coords={"chain": np.arange(n_chains)}, - ) idata.add_groups( - posterior=(idata["prior"].mean("chain") * expanded_chains).transpose("chain", "draw", ...), + posterior=( + idata["prior"] + .isel(chain=0) + .expand_dims({"chain": range(n_chains)}) + .transpose("chain", "draw", ...) + ) ) del idata["prior"] if "prior_predictive" in idata: diff --git a/tests/test_testing.py b/tests/test_testing.py index 78002d03a9..105e2f6209 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -38,14 +38,16 @@ def test_domain(values, edges, expectation): @pytest.mark.parametrize( - "args, kwargs, expected_draws", + "args, kwargs, expected_size", [ - pytest.param((), {}, 10, id="default"), - pytest.param((100,), {}, 100, id="positional-draws"), - pytest.param((), {"draws": 100}, 100, id="keyword-draws"), + pytest.param((), {}, (1, 10), id="default"), + pytest.param((100,), {}, (1, 100), id="positional-draws"), + pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"), + pytest.param((100,), {"chains": 6}, (6, 100), id="chains"), ], ) -def test_mock_sample(args, kwargs, expected_draws) -> None: +def test_mock_sample(args, kwargs, expected_size) -> None: + expected_chains, expected_draws = expected_size _, model, _ = simple_normal(bounded_prior=True) with model: @@ -57,7 +59,7 @@ def test_mock_sample(args, kwargs, expected_draws) -> None: assert "posterior_predictive" not in idata assert "sample_stats" not in idata - assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws} + assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws} mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) From 34404ae916ad2dd30a5d3df6077ddfe9cc6e38da Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:34:24 -0400 Subject: [PATCH 11/14] add to the toc --- docs/source/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index a82da9bc99..d80c0984ff 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -22,6 +22,7 @@ API api/shape_utils api/backends api/misc + api/testing ------------------ Dimensionality From 91f9c5fb1340230330f9dee29ea9ef887011a647 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:42:00 -0400 Subject: [PATCH 12/14] alterations to docstrings --- pymc/testing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pymc/testing.py b/pymc/testing.py index a7ec51152b..d0b94f9327 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -994,6 +994,10 @@ def mock_sample(draws: int = 10, **kwargs): -------- Using mock_sample with pytest + .. note:: + + Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures. + .. code-block:: python import pytest @@ -1043,6 +1047,7 @@ def mock_sample_setup_and_teardown(): PyMC's sampling functionality with faster alternatives for testing purposes. Effects during the fixture's active period: + * Replaces pm.sample with mock_sample, which uses prior predictive sampling instead of MCMC * Replaces pm.Flat with pm.Normal to avoid issues with unbounded priors @@ -1051,8 +1056,11 @@ def mock_sample_setup_and_teardown(): Examples -------- + Use with `pytest` to mock actual PyMC sampling in test suite. + .. code-block:: python + # tests/conftest.py import pytest import pymc as pm from pymc.testing import mock_sample_setup_and_teardown @@ -1061,6 +1069,7 @@ def mock_sample_setup_and_teardown(): mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) + # tests/test_model.py # Use in a test function def test_model_inference(mock_pymc_sample): with pm.Model() as model: From 8a1f76fa58dff22be9f5ea0cc41adabbc8df7df2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:48:07 -0400 Subject: [PATCH 13/14] change format and provide links --- pymc/testing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index d0b94f9327..38bf2cfafb 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1048,11 +1048,12 @@ def mock_sample_setup_and_teardown(): Effects during the fixture's active period: - * Replaces pm.sample with mock_sample, which uses prior predictive sampling - instead of MCMC - * Replaces pm.Flat with pm.Normal to avoid issues with unbounded priors - * Replaces pm.HalfFlat with pm.HalfNormal to avoid issues with semi-bounded priors - * Automatically restores all original functions after the test completes + * Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses + prior predictive sampling instead of MCMC + * Replaces distributions: + * :class:`pymc.Flat` with :class:`pymc.Normal` + * :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal` + * Automatically restores all original functions and distributions after the test completes Examples -------- From d3a8c1cb614b385727e5f10f2d185dfebd4c9865 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 24 Apr 2025 15:54:29 -0400 Subject: [PATCH 14/14] link to the functions in the docstring --- pymc/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index 38bf2cfafb..a5fdc28327 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -986,7 +986,7 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: def mock_sample(draws: int = 10, **kwargs): - """Mock the pm.sample function by returning prior predictive samples as posterior. + """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`. Useful for testing models that use pm.sample without running MCMC sampling.