diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 867a6e0fc..753cbf9e4 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -8,7 +8,8 @@ methods in the current release of PyMC experimental. .. autosummary:: :toctree: generated/ - marginal_model.MarginalModel + as_model + MarginalModel model_builder.ModelBuilder Inference diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index ee1cc8798..e319ef35f 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -25,4 +25,5 @@ from pymc_experimental import distributions, gp, utils from pymc_experimental.inference.fit import fit -from pymc_experimental.marginal_model import MarginalModel +from pymc_experimental.model.marginal_model import MarginalModel +from pymc_experimental.model.model_api import as_model diff --git a/pymc_experimental/model/__init__.py b/pymc_experimental/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/model/marginal_model.py similarity index 100% rename from pymc_experimental/marginal_model.py rename to pymc_experimental/model/marginal_model.py diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py new file mode 100644 index 000000000..8425a764f --- /dev/null +++ b/pymc_experimental/model/model_api.py @@ -0,0 +1,46 @@ +from functools import wraps + +from pymc import Model + + +def as_model(*model_args, **model_kwargs): + R""" + Decorator to provide context to PyMC models declared in a function. + This removes all need to think about context managers and lets you separate creating a generative model from using the model. + + Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3. + + Examples + -------- + .. code:: python + + import pymc as pm + import pymc_experimental as pmx + + # The following are equivalent + + # standard PyMC API with context manager + with pm.Model(coords={"obs": ["a", "b"]}) as model: + x = pm.Normal("x", 0., 1., dims="obs") + pm.sample() + + # functional API using decorator + @pmx.as_model(coords={"obs": ["a", "b"]}) + def basic_model(): + pm.Normal("x", 0., 1., dims="obs") + + m = basic_model() + pm.sample(model=m) + + """ + + def decorator(f): + @wraps(f) + def make_model(*args, **kwargs): + with Model(*model_args, **model_kwargs) as m: + f(*args, **kwargs) + return m + + return make_model + + return decorator diff --git a/pymc_experimental/tests/model/__init__.py b/pymc_experimental/tests/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_experimental/tests/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py similarity index 99% rename from pymc_experimental/tests/test_marginal_model.py rename to pymc_experimental/tests/model/test_marginal_model.py index d749fba80..f8ed5718d 100644 --- a/pymc_experimental/tests/test_marginal_model.py +++ b/pymc_experimental/tests/model/test_marginal_model.py @@ -12,7 +12,7 @@ from pymc.util import UNSET from scipy.special import logsumexp -from pymc_experimental.marginal_model import ( +from pymc_experimental.model.marginal_model import ( FiniteDiscreteMarginalRV, MarginalModel, is_conditional_dependent, diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py new file mode 100644 index 000000000..a36c5e5b6 --- /dev/null +++ b/pymc_experimental/tests/model/test_model_api.py @@ -0,0 +1,22 @@ +import numpy as np +import pymc as pm + +import pymc_experimental as pmx + + +def test_logp(): + """Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator + and a functional syntax. Checks whether the kwarg `coords` can be passed. + """ + coords = {"obs": ["a", "b"]} + + with pm.Model(coords=coords) as model: + pm.Normal("x", 0.0, 1.0, dims="obs") + + @pmx.as_model(coords=coords) + def model_wrapped(): + pm.Normal("x", 0.0, 1.0, dims="obs") + + mw = model_wrapped() + + np.testing.assert_equal(model.point_logps(), mw.point_logps())