diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 867a6e0fc..753cbf9e4 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -8,7 +8,8 @@ methods in the current release of PyMC experimental.
.. autosummary::
:toctree: generated/
- marginal_model.MarginalModel
+ as_model
+ MarginalModel
model_builder.ModelBuilder
Inference
diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py
index ee1cc8798..e319ef35f 100644
--- a/pymc_experimental/__init__.py
+++ b/pymc_experimental/__init__.py
@@ -25,4 +25,5 @@
from pymc_experimental import distributions, gp, utils
from pymc_experimental.inference.fit import fit
-from pymc_experimental.marginal_model import MarginalModel
+from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.model_api import as_model
diff --git a/pymc_experimental/model/__init__.py b/pymc_experimental/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/model/marginal_model.py
similarity index 100%
rename from pymc_experimental/marginal_model.py
rename to pymc_experimental/model/marginal_model.py
diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py
new file mode 100644
index 000000000..8425a764f
--- /dev/null
+++ b/pymc_experimental/model/model_api.py
@@ -0,0 +1,46 @@
+from functools import wraps
+
+from pymc import Model
+
+
+def as_model(*model_args, **model_kwargs):
+ R"""
+ Decorator to provide context to PyMC models declared in a function.
+ This removes all need to think about context managers and lets you separate creating a generative model from using the model.
+
+ Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3.
+
+ Examples
+ --------
+ .. code:: python
+
+ import pymc as pm
+ import pymc_experimental as pmx
+
+ # The following are equivalent
+
+ # standard PyMC API with context manager
+ with pm.Model(coords={"obs": ["a", "b"]}) as model:
+ x = pm.Normal("x", 0., 1., dims="obs")
+ pm.sample()
+
+ # functional API using decorator
+ @pmx.as_model(coords={"obs": ["a", "b"]})
+ def basic_model():
+ pm.Normal("x", 0., 1., dims="obs")
+
+ m = basic_model()
+ pm.sample(model=m)
+
+ """
+
+ def decorator(f):
+ @wraps(f)
+ def make_model(*args, **kwargs):
+ with Model(*model_args, **model_kwargs) as m:
+ f(*args, **kwargs)
+ return m
+
+ return make_model
+
+ return decorator
diff --git a/pymc_experimental/tests/model/__init__.py b/pymc_experimental/tests/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pymc_experimental/tests/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
similarity index 99%
rename from pymc_experimental/tests/test_marginal_model.py
rename to pymc_experimental/tests/model/test_marginal_model.py
index d749fba80..f8ed5718d 100644
--- a/pymc_experimental/tests/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -12,7 +12,7 @@
from pymc.util import UNSET
from scipy.special import logsumexp
-from pymc_experimental.marginal_model import (
+from pymc_experimental.model.marginal_model import (
FiniteDiscreteMarginalRV,
MarginalModel,
is_conditional_dependent,
diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py
new file mode 100644
index 000000000..a36c5e5b6
--- /dev/null
+++ b/pymc_experimental/tests/model/test_model_api.py
@@ -0,0 +1,22 @@
+import numpy as np
+import pymc as pm
+
+import pymc_experimental as pmx
+
+
+def test_logp():
+ """Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
+ and a functional syntax. Checks whether the kwarg `coords` can be passed.
+ """
+ coords = {"obs": ["a", "b"]}
+
+ with pm.Model(coords=coords) as model:
+ pm.Normal("x", 0.0, 1.0, dims="obs")
+
+ @pmx.as_model(coords=coords)
+ def model_wrapped():
+ pm.Normal("x", 0.0, 1.0, dims="obs")
+
+ mw = model_wrapped()
+
+ np.testing.assert_equal(model.point_logps(), mw.point_logps())