From bb48c880caa452565a9e20ce64c10e91811a154d Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 14:09:28 +0000 Subject: [PATCH 01/10] move marginal_model to model/ --- pymc_experimental/__init__.py | 3 ++- pymc_experimental/model/__init__.py | 0 pymc_experimental/{ => model}/marginal_model.py | 0 pymc_experimental/tests/model/__init__.py | 0 pymc_experimental/tests/{ => model}/test_marginal_model.py | 2 +- 5 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 pymc_experimental/model/__init__.py rename pymc_experimental/{ => model}/marginal_model.py (100%) create mode 100644 pymc_experimental/tests/model/__init__.py rename pymc_experimental/tests/{ => model}/test_marginal_model.py (99%) diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index ee1cc8798..1333b35d5 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -25,4 +25,5 @@ from pymc_experimental import distributions, gp, utils from pymc_experimental.inference.fit import fit -from pymc_experimental.marginal_model import MarginalModel +from pymc_experimental.model.marginal_model import MarginalModel +from pymc_experimental.model.model_api import model diff --git a/pymc_experimental/model/__init__.py b/pymc_experimental/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/model/marginal_model.py similarity index 100% rename from pymc_experimental/marginal_model.py rename to pymc_experimental/model/marginal_model.py diff --git a/pymc_experimental/tests/model/__init__.py b/pymc_experimental/tests/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_experimental/tests/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py similarity index 99% rename from pymc_experimental/tests/test_marginal_model.py rename to pymc_experimental/tests/model/test_marginal_model.py index d749fba80..f8ed5718d 100644 --- a/pymc_experimental/tests/test_marginal_model.py +++ b/pymc_experimental/tests/model/test_marginal_model.py @@ -12,7 +12,7 @@ from pymc.util import UNSET from scipy.special import logsumexp -from pymc_experimental.marginal_model import ( +from pymc_experimental.model.marginal_model import ( FiniteDiscreteMarginalRV, MarginalModel, is_conditional_dependent, From 5c7d42a7d751975b3b2e7a1160e9e29c0ed148f2 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 14:10:24 +0000 Subject: [PATCH 02/10] add model decorator --- pymc_experimental/model/model_api.py | 45 +++++++++++++++++++ .../tests/model/test_model_api.py | 25 +++++++++++ 2 files changed, 70 insertions(+) create mode 100644 pymc_experimental/model/model_api.py create mode 100644 pymc_experimental/tests/model/test_model_api.py diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py new file mode 100644 index 000000000..a1b799cf9 --- /dev/null +++ b/pymc_experimental/model/model_api.py @@ -0,0 +1,45 @@ +from functools import wraps + +import pymc as pm + + +def model(*model_args, **model_kwargs): + R""" + Decorator to provide context to PyMC models declared in a function. + This removes all need to think about context managers and lets you separate creating a generative model from using the model. + + Inspired by the [sampled](https://github.com/colcarroll/sampled) decorator for PyMC3. + + Examples + -------- + .. code:: python + + # The following are equivalent + + # standard PyMC API with context manager + with pm.Model(coords={"obs": ["a", "b"]}) as model: + x = pm.Normal("x", 0., 1., dims="obs") + pm.sample() + + # functional API using decorator + @pmx.model(coords={"obs": ["a", "b"]}) + def basic_model(): + pm.Normal("x", 0., 1., dims="obs") + + m = basic_model() + pm.sample(model=m) + + ... + + """ + + def decorator(f): + @wraps(f) + def make_model(*args, **kwargs): + with pm.Model(*model_args, **model_kwargs) as m: + f(*args, **kwargs) + return m + + return make_model + + return decorator diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py new file mode 100644 index 000000000..81e77ba50 --- /dev/null +++ b/pymc_experimental/tests/model/test_model_api.py @@ -0,0 +1,25 @@ +import numpy as np +import pymc as pm + +import pymc_experimental as pmx + + +def test_sample(): + """Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator + and a functional syntax. Checks whether the kwarge `coords` can be passed. + """ + coords = {"obs": ["a", "b"]} + kwargs = {"draws": 50, "tune": 50, "chains": 1, "random_seed": 1} + + with pm.Model(coords=coords) as model: + pm.Normal("x", 0.0, 1.0, dims="obs") + idata = pm.sample(**kwargs) + + @pmx.model(coords=coords) + def model_wrapped(): + pm.Normal("x", 0.0, 1.0, dims="obs") + + mw = model_wrapped() + idata_wrapped = pm.sample(model=mw, **kwargs) + + np.testing.assert_array_equal(idata.posterior.x, idata_wrapped.posterior.x) From c5f62f5c4ce1704d4d7f57ba516ebc7b997527d0 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 15:13:31 +0000 Subject: [PATCH 03/10] expose to docs --- docs/api_reference.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 867a6e0fc..c7c7ce88b 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -8,7 +8,8 @@ methods in the current release of PyMC experimental. .. autosummary:: :toctree: generated/ - marginal_model.MarginalModel + model + MarginalModel model_builder.ModelBuilder Inference From 38007c2e715484784320092fd8a853c3b8574d29 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 15:14:04 +0000 Subject: [PATCH 04/10] no ... at end of codeblock --- pymc_experimental/model/model_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index a1b799cf9..0bfcd0a0b 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -29,8 +29,6 @@ def basic_model(): m = basic_model() pm.sample(model=m) - ... - """ def decorator(f): From 52a04e723d02843610c5b350606f0d34c6518904 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 17:15:36 +0000 Subject: [PATCH 05/10] cheaper test using point_logps --- pymc_experimental/tests/model/test_model_api.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py index 81e77ba50..7d3b8f1df 100644 --- a/pymc_experimental/tests/model/test_model_api.py +++ b/pymc_experimental/tests/model/test_model_api.py @@ -4,22 +4,19 @@ import pymc_experimental as pmx -def test_sample(): +def test_logp(): """Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator - and a functional syntax. Checks whether the kwarge `coords` can be passed. + and a functional syntax. Checks whether the kwarg `coords` can be passed. """ coords = {"obs": ["a", "b"]} - kwargs = {"draws": 50, "tune": 50, "chains": 1, "random_seed": 1} with pm.Model(coords=coords) as model: pm.Normal("x", 0.0, 1.0, dims="obs") - idata = pm.sample(**kwargs) @pmx.model(coords=coords) def model_wrapped(): pm.Normal("x", 0.0, 1.0, dims="obs") mw = model_wrapped() - idata_wrapped = pm.sample(model=mw, **kwargs) - np.testing.assert_array_equal(idata.posterior.x, idata_wrapped.posterior.x) + np.testing.assert_array_equal(model.point_logps(), mw.point_logps()) From 5c83a472f28b90f1f48b853ca551710d74014855 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 17:21:47 +0000 Subject: [PATCH 06/10] fix doc link and imports --- pymc_experimental/model/model_api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index 0bfcd0a0b..7a4c6be18 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -8,12 +8,15 @@ def model(*model_args, **model_kwargs): Decorator to provide context to PyMC models declared in a function. This removes all need to think about context managers and lets you separate creating a generative model from using the model. - Inspired by the [sampled](https://github.com/colcarroll/sampled) decorator for PyMC3. + Inspired by the `sampled `_ decorator for PyMC3. Examples -------- .. code:: python + import pymc as pm + import pymc_experimental as pmx + # The following are equivalent # standard PyMC API with context manager From f4aff55774f7513e27de8eb5a16929c9a8efe1ec Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 17:44:13 +0000 Subject: [PATCH 07/10] rename `model` as `as_model` --- docs/api_reference.rst | 2 +- pymc_experimental/__init__.py | 2 +- pymc_experimental/model/model_api.py | 4 ++-- pymc_experimental/tests/model/test_model_api.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index c7c7ce88b..753cbf9e4 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -8,7 +8,7 @@ methods in the current release of PyMC experimental. .. autosummary:: :toctree: generated/ - model + as_model MarginalModel model_builder.ModelBuilder diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 1333b35d5..e319ef35f 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -26,4 +26,4 @@ from pymc_experimental import distributions, gp, utils from pymc_experimental.inference.fit import fit from pymc_experimental.model.marginal_model import MarginalModel -from pymc_experimental.model.model_api import model +from pymc_experimental.model.model_api import as_model diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index 7a4c6be18..92a3d22d5 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -3,7 +3,7 @@ import pymc as pm -def model(*model_args, **model_kwargs): +def as_model(*model_args, **model_kwargs): R""" Decorator to provide context to PyMC models declared in a function. This removes all need to think about context managers and lets you separate creating a generative model from using the model. @@ -25,7 +25,7 @@ def model(*model_args, **model_kwargs): pm.sample() # functional API using decorator - @pmx.model(coords={"obs": ["a", "b"]}) + @pmx.as_model(coords={"obs": ["a", "b"]}) def basic_model(): pm.Normal("x", 0., 1., dims="obs") diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py index 7d3b8f1df..e54517150 100644 --- a/pymc_experimental/tests/model/test_model_api.py +++ b/pymc_experimental/tests/model/test_model_api.py @@ -13,7 +13,7 @@ def test_logp(): with pm.Model(coords=coords) as model: pm.Normal("x", 0.0, 1.0, dims="obs") - @pmx.model(coords=coords) + @pmx.as_model(coords=coords) def model_wrapped(): pm.Normal("x", 0.0, 1.0, dims="obs") From 84306dfc38b0c5fcc2e9dc774ecc950ce4236788 Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 20 Nov 2023 21:17:29 +0000 Subject: [PATCH 08/10] us assert_equal instead of assert_array_equal --- pymc_experimental/tests/model/test_model_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py index e54517150..a36c5e5b6 100644 --- a/pymc_experimental/tests/model/test_model_api.py +++ b/pymc_experimental/tests/model/test_model_api.py @@ -19,4 +19,4 @@ def model_wrapped(): mw = model_wrapped() - np.testing.assert_array_equal(model.point_logps(), mw.point_logps()) + np.testing.assert_equal(model.point_logps(), mw.point_logps()) From b5210f7f0b55ff382e64291f0b2ccebbae7f0a79 Mon Sep 17 00:00:00 2001 From: theorashid Date: Tue, 21 Nov 2023 13:41:32 +0000 Subject: [PATCH 09/10] update docs to link to zaxtax blog Co-authored-by: Rob Zinkov <8529+zaxtax@users.noreply.github.com> --- pymc_experimental/model/model_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index 92a3d22d5..cd68600c9 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -8,7 +8,7 @@ def as_model(*model_args, **model_kwargs): Decorator to provide context to PyMC models declared in a function. This removes all need to think about context managers and lets you separate creating a generative model from using the model. - Inspired by the `sampled `_ decorator for PyMC3. + Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3. Examples -------- From d0a5ae0d9c39d94fe82d54fcb1aa477770c31926 Mon Sep 17 00:00:00 2001 From: theorashid Date: Wed, 22 Nov 2023 16:09:36 +0000 Subject: [PATCH 10/10] change import to from pymc import Model --- pymc_experimental/model/model_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index cd68600c9..8425a764f 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -1,6 +1,6 @@ from functools import wraps -import pymc as pm +from pymc import Model def as_model(*model_args, **model_kwargs): @@ -37,7 +37,7 @@ def basic_model(): def decorator(f): @wraps(f) def make_model(*args, **kwargs): - with pm.Model(*model_args, **model_kwargs) as m: + with Model(*model_args, **model_kwargs) as m: f(*args, **kwargs) return m