Add public testing function to mock sample #7761

Merged

merged 14 commits into from
May 1, 2025
1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -22,6 +22,7 @@ API
api/shape_utils
api/backends
api/misc
api/testing

------------------
Dimensionality
14 changes: 14 additions & 0 deletions docs/source/api/testing.rst
@@ -0,0 +1,14 @@
=======
Testing
=======

This submodule contains tools to help with testing PyMC code.


.. currentmodule:: pymc.testing

.. autosummary::
:toctree: generated/

mock_sample
mock_sample_setup_and_teardown
113 changes: 113 additions & 0 deletions pymc/testing.py
@@ -22,6 +22,7 @@
import pytensor
import pytensor.tensor as pt

from arviz import InferenceData
from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
@@ -982,3 +983,115 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
    rvs = rvs_in_graph(vars)
    if rvs:
        raise AssertionError(f"RV found in graph: {rvs}")


def mock_sample(draws: int = 10, **kwargs):
    """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.

    Useful for testing models that use pm.sample without running MCMC sampling.

    Examples
    --------
    Using mock_sample with pytest

    .. note::

        Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures.

    .. code-block:: python

        import pytest

        import pymc as pm
        from pymc.testing import mock_sample


        @pytest.fixture(scope="module")
        def mock_pymc_sample():
            original_sample = pm.sample
            pm.sample = mock_sample

            yield

            pm.sample = original_sample

    """
    random_seed = kwargs.get("random_seed", None)
    model = kwargs.get("model", None)
    draws = kwargs.get("draws", draws)
    n_chains = kwargs.get("chains", 1)
    idata: InferenceData = pm.sample_prior_predictive(
        model=model,
        random_seed=random_seed,
        draws=draws,
    )

    idata.add_groups(
        posterior=(
            idata["prior"]
            .isel(chain=0)
            .expand_dims({"chain": range(n_chains)})
            .transpose("chain", "draw", ...)
        )
    )
    del idata["prior"]
    if "prior_predictive" in idata:
        del idata["prior_predictive"]
    return idata
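
The body above relabels prior predictive draws as a ``posterior`` group so downstream code that expects MCMC output keeps working. A minimal sketch of calling ``mock_sample`` directly; the model and parameter values here are illustrative and not part of this PR:

import pymc as pm

from pymc.testing import mock_sample

with pm.Model():
    pm.Normal("x", mu=0, sigma=1)
    idata = mock_sample(draws=50, chains=2)

# Prior draws come back relabelled as a posterior group with the requested
# chain/draw shape; the prior groups themselves are removed.
assert "posterior" in idata
assert "prior" not in idata
assert idata.posterior.sizes["chain"] == 2
assert idata.posterior.sizes["draw"] == 50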


def mock_sample_setup_and_teardown():
    """Set up and tear down mocking of PyMC sampling functions for testing.

    This function is designed to be used with pytest fixtures to temporarily replace
    PyMC's sampling functionality with faster alternatives for testing purposes.

    Effects during the fixture's active period:

    * Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses
      prior predictive sampling instead of MCMC
    * Replaces distributions:

      * :class:`pymc.Flat` with :class:`pymc.Normal`
      * :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal`

    * Automatically restores all original functions and distributions after the test completes

    Examples
    --------
    Use with `pytest` to mock actual PyMC sampling in a test suite.

    .. code-block:: python

        # tests/conftest.py
        import pytest
        import pymc as pm
        from pymc.testing import mock_sample_setup_and_teardown

        # Register as a pytest fixture
        mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


        # tests/test_model.py
        # Use in a test function
        def test_model_inference(mock_pymc_sample):
            with pm.Model() as model:
                x = pm.Normal("x", 0, 1)
                # This will use mock_sample instead of actual MCMC
                idata = pm.sample()
            # Test with the inference data...

    """
    import pymc as pm

    original_flat = pm.Flat
    original_half_flat = pm.HalfFlat
    original_sample = pm.sample

    pm.sample = mock_sample
    pm.Flat = pm.Normal
    pm.HalfFlat = pm.HalfNormal

    yield

    pm.sample = original_sample
    pm.Flat = original_flat
    pm.HalfFlat = original_half_flat
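
Because flat (improper) priors cannot be forward-sampled, the ``Flat``/``HalfFlat`` swaps are what make prior predictive sampling viable for models that use them. The helper itself is a plain single-yield generator, so besides the documented ``pytest.fixture`` registration it could, for example, be wrapped with ``contextlib.contextmanager`` for ad-hoc use. A sketch under that assumption; the name ``mocked_pymc_sampling`` is illustrative and not part of this PR:

from contextlib import contextmanager

import pymc as pm

from pymc.testing import mock_sample_setup_and_teardown

# Wrap the single-yield generator so the patching applies only inside the block.
mocked_pymc_sampling = contextmanager(mock_sample_setup_and_teardown)

with mocked_pymc_sampling():
    with pm.Model():
        pm.Flat("intercept")  # sampled as pm.Normal while the mock is active
        idata = pm.sample()  # dispatches to mock_sample, no MCMC is run

assert idata.posterior.sizes == {"chain": 1, "draw": 10}  # mock_sample defaults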
51 changes: 50 additions & 1 deletion tests/test_testing.py
@@ -15,7 +15,10 @@

import pytest

from pymc.testing import Domain
import pymc as pm

from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown
from tests.models import simple_normal


@pytest.mark.parametrize(
@@ -32,3 +35,49 @@
def test_domain(values, edges, expectation):
    with expectation:
        Domain(values, edges=edges)


@pytest.mark.parametrize(
    "args, kwargs, expected_size",
    [
        pytest.param((), {}, (1, 10), id="default"),
        pytest.param((100,), {}, (1, 100), id="positional-draws"),
        pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
        pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
    ],
)
def test_mock_sample(args, kwargs, expected_size) -> None:
    expected_chains, expected_draws = expected_size
    _, model, _ = simple_normal(bounded_prior=True)

    with model:
        idata = mock_sample(*args, **kwargs)

    assert "posterior" in idata
    assert "observed_data" in idata
    assert "prior" not in idata
    assert "posterior_predictive" not in idata
    assert "sample_stats" not in idata

    assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}


mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="function")
def dummy_model() -> pm.Model:
    with pm.Model() as model:
        pm.Flat("flat")
        pm.HalfFlat("half_flat")

    return model


def test_fixture(mock_pymc_sample, dummy_model) -> None:
    with dummy_model:
        idata = pm.sample()

    posterior = idata.posterior
    assert posterior.sizes == {"chain": 1, "draw": 10}
    assert (posterior["half_flat"] >= 0).all()