From 299e5562966bd82a1b7f837efc900fd8ca791ccd Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 08:45:19 -0600 Subject: [PATCH 01/14] Reinstate nuts_kwargs in sample for passing arguments to numpyro --- pymc/sampling/mcmc.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 0449b029f8..0888c37d37 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -237,6 +237,7 @@ def _sample_external_nuts( model: Model, progressbar: bool, idata_kwargs: Optional[Dict], + nuts_kwargs: Optional[Dict], **kwargs, ): warnings.warn("Use of external NUTS sampler is still experimental", UserWarning) @@ -287,7 +288,12 @@ def _sample_external_nuts( initvals=initvals, model=model, progressbar=progressbar, + keep_untransformed=nuts_kwargs.get("keep_untransformed", False), + chain_method=nuts_kwargs.get("chain_method", "parallel"), + postprocessing_backend=nuts_kwargs.get("postprocessing_backend"), + postprocessing_chunks=nuts_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, + nuts_kwargs=nuts_kwargs, **kwargs, ) return idata @@ -304,7 +310,6 @@ def _sample_external_nuts( initvals=initvals, model=model, idata_kwargs=idata_kwargs, - **kwargs, ) return idata @@ -334,6 +339,7 @@ def sample( keep_warning_stat: bool = False, return_inferencedata: bool = True, idata_kwargs: Optional[Dict[str, Any]] = None, + nuts_kwargs: Optional[Dict[str, Any]] = None, callback=None, mp_ctx=None, model: Optional[Model] = None, @@ -410,6 +416,8 @@ def sample( `MultiTrace` (False). Defaults to `True`. idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` + nuts_kwargs : dict, optional + Keyword arguments for the NUTS sampler. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. @@ -493,15 +501,14 @@ def sample( stacklevel=2, ) initvals = kwargs.pop("start") + if nuts_kwargs is None: + nuts_kwargs = {} if "target_accept" in kwargs: - if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: + if nuts_kwargs is not None and "target_accept" in nuts_kwargs: raise ValueError( - "`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts` kwarg." + "`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts_kwargs` dict." ) - if "nuts" in kwargs: - kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept") - else: - kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")} + nuts_kwargs["target_accept"] = kwargs.pop("target_accept") if isinstance(trace, list): raise DeprecationWarning( "We have removed support for partial traces because it simplified things." @@ -563,20 +570,20 @@ def sample( draws=draws, tune=tune, chains=chains, - target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), + target_accept=nuts_kwargs.get("target_accept", 0.8), random_seed=random_seed, initvals=initvals, model=model, progressbar=progressbar, idata_kwargs=idata_kwargs, + nuts_kwargs=nuts_kwargs, **kwargs, ) if isinstance(step, list): step = CompoundStep(step) elif isinstance(step, NUTS) and auto_nuts_init: - if "nuts" in kwargs: - nuts_kwargs = kwargs.pop("nuts") + if nuts_kwargs is not None: [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] _log.info("Auto-assigning NUTS sampler...") initial_points, step = init_nuts( From b9befb67841e43354884820d1cc52a7cbfde78f2 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 08:48:52 -0600 Subject: [PATCH 02/14] Removed redundant condition from if statement --- pymc/sampling/mcmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 0888c37d37..8e6ad7011d 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -504,7 +504,7 @@ def sample( if nuts_kwargs is None: nuts_kwargs = {} if "target_accept" in kwargs: - if nuts_kwargs is not None and "target_accept" in nuts_kwargs: + if "target_accept" in nuts_kwargs: raise ValueError( "`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts_kwargs` dict." ) From 2fbaab23972875a609b2b63aac11fc384f10d8ad Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 09:01:59 -0600 Subject: [PATCH 03/14] Added sampler_kwargs --- pymc/sampling/mcmc.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 8e6ad7011d..9d3baf0b1b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -238,6 +238,7 @@ def _sample_external_nuts( progressbar: bool, idata_kwargs: Optional[Dict], nuts_kwargs: Optional[Dict], + sampler_kwargs: Optional[Dict], **kwargs, ): warnings.warn("Use of external NUTS sampler is still experimental", UserWarning) @@ -288,10 +289,10 @@ def _sample_external_nuts( initvals=initvals, model=model, progressbar=progressbar, - keep_untransformed=nuts_kwargs.get("keep_untransformed", False), - chain_method=nuts_kwargs.get("chain_method", "parallel"), - postprocessing_backend=nuts_kwargs.get("postprocessing_backend"), - postprocessing_chunks=nuts_kwargs.get("postprocessing_chunks"), + keep_untransformed=sampler_kwargs.get("keep_untransformed", False), + chain_method=sampler_kwargs.get("chain_method", "parallel"), + postprocessing_backend=sampler_kwargs.get("postprocessing_backend"), + postprocessing_chunks=sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, **kwargs, @@ -309,6 +310,10 @@ def _sample_external_nuts( random_seed=random_seed, initvals=initvals, model=model, + keep_untransformed=sampler_kwargs.get("keep_untransformed", False), + chain_method=sampler_kwargs.get("chain_method", "parallel"), + postprocessing_backend=sampler_kwargs.get("postprocessing_backend"), + postprocessing_chunks=sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, ) return idata @@ -340,6 +345,7 @@ def sample( return_inferencedata: bool = True, idata_kwargs: Optional[Dict[str, Any]] = None, nuts_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, callback=None, mp_ctx=None, model: Optional[Model] = None, @@ -418,6 +424,8 @@ def sample( Keyword arguments for :func:`pymc.to_inference_data` nuts_kwargs : dict, optional Keyword arguments for the NUTS sampler. + sampler_kwargs : dict, optional + Keyword arguments for the sampler. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. @@ -442,7 +450,7 @@ def sample( For example: - 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9} + 1. ``target_accept`` to NUTS: nuts_kwargs={'target_accept':0.9} 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7} Note that available step names are: @@ -577,6 +585,7 @@ def sample( progressbar=progressbar, idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, + sampler_kwargs=sampler_kwargs, **kwargs, ) From ded6cbf2b1b78589b0b8473fda2bb3438aea379d Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 09:51:38 -0600 Subject: [PATCH 04/14] Rename sampler_kwargs; fix test failures --- pymc/sampling/mcmc.py | 29 +++++++++++++++++------------ tests/sampling/test_mcmc.py | 8 ++++---- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 9d3baf0b1b..c9c57a46d4 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -238,11 +238,14 @@ def _sample_external_nuts( progressbar: bool, idata_kwargs: Optional[Dict], nuts_kwargs: Optional[Dict], - sampler_kwargs: Optional[Dict], + nuts_sampler_kwargs: Optional[Dict], **kwargs, ): warnings.warn("Use of external NUTS sampler is still experimental", UserWarning) + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} + if sampler == "nutpie": try: import nutpie @@ -289,10 +292,10 @@ def _sample_external_nuts( initvals=initvals, model=model, progressbar=progressbar, - keep_untransformed=sampler_kwargs.get("keep_untransformed", False), - chain_method=sampler_kwargs.get("chain_method", "parallel"), - postprocessing_backend=sampler_kwargs.get("postprocessing_backend"), - postprocessing_chunks=sampler_kwargs.get("postprocessing_chunks"), + keep_untransformed=nuts_sampler_kwargs.get("keep_untransformed", False), + chain_method=nuts_sampler_kwargs.get("chain_method", "parallel"), + postprocessing_backend=nuts_sampler_kwargs.get("postprocessing_backend"), + postprocessing_chunks=nuts_sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, **kwargs, @@ -310,10 +313,10 @@ def _sample_external_nuts( random_seed=random_seed, initvals=initvals, model=model, - keep_untransformed=sampler_kwargs.get("keep_untransformed", False), - chain_method=sampler_kwargs.get("chain_method", "parallel"), - postprocessing_backend=sampler_kwargs.get("postprocessing_backend"), - postprocessing_chunks=sampler_kwargs.get("postprocessing_chunks"), + keep_untransformed=nuts_sampler_kwargs.get("keep_untransformed", False), + chain_method=nuts_sampler_kwargs.get("chain_method", "parallel"), + postprocessing_backend=nuts_sampler_kwargs.get("postprocessing_backend"), + postprocessing_chunks=nuts_sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, ) return idata @@ -345,7 +348,7 @@ def sample( return_inferencedata: bool = True, idata_kwargs: Optional[Dict[str, Any]] = None, nuts_kwargs: Optional[Dict[str, Any]] = None, - sampler_kwargs: Optional[Dict[str, Any]] = None, + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, callback=None, mp_ctx=None, model: Optional[Model] = None, @@ -424,7 +427,7 @@ def sample( Keyword arguments for :func:`pymc.to_inference_data` nuts_kwargs : dict, optional Keyword arguments for the NUTS sampler. - sampler_kwargs : dict, optional + nuts_sampler_kwargs : dict, optional Keyword arguments for the sampler. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is @@ -509,6 +512,8 @@ def sample( stacklevel=2, ) initvals = kwargs.pop("start") + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} if nuts_kwargs is None: nuts_kwargs = {} if "target_accept" in kwargs: @@ -585,7 +590,7 @@ def sample( progressbar=progressbar, idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, - sampler_kwargs=sampler_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, ) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index dbdc0d2a40..31f3338617 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -644,11 +644,11 @@ def test_step_args(): with pm.Model() as model: a = pm.Normal("a") idata0 = pm.sample(target_accept=0.5, random_seed=1410) - idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2) - idata2 = pm.sample(target_accept=0.5, nuts={"max_treedepth": 10}, random_seed=1410) + idata1 = pm.sample(nuts_kwargs={"target_accept": 0.5}, random_seed=1410 * 2) + idata2 = pm.sample(target_accept=0.5, nuts_kwargs={"max_treedepth": 10}, random_seed=1410) with pytest.raises(ValueError, match="`target_accept` was defined twice."): - pm.sample(target_accept=0.5, nuts={"target_accept": 0.95}, random_seed=1410) + pm.sample(target_accept=0.5, nuts_kwargs={"target_accept": 0.95}, random_seed=1410) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) @@ -663,7 +663,7 @@ def test_step_args(): "ignore", "invalid value encountered in double_scalars", RuntimeWarning ) idata1 = pm.sample( - nuts={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 + nuts_kwargs={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) From c99f18b4fafa74eafccd9f5cd87f546720872e08 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 11:05:48 -0600 Subject: [PATCH 05/14] Ensure nuts_kwargs get passed to pymc nuts --- pymc/sampling/mcmc.py | 4 ++-- tests/sampling/test_mcmc.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index c9c57a46d4..50d55a0f15 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -597,8 +597,8 @@ def sample( if isinstance(step, list): step = CompoundStep(step) elif isinstance(step, NUTS) and auto_nuts_init: - if nuts_kwargs is not None: - [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] + for k, v in nuts_kwargs.items(): + kwargs.setdefault(k, v) _log.info("Auto-assigning NUTS sampler...") initial_points, step = init_nuts( init=init, diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 31f3338617..ea220ec7cb 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -665,6 +665,12 @@ def test_step_args(): idata1 = pm.sample( nuts_kwargs={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 ) + idata2 = pm.sample( + nuts_sampler="numpyro", + nuts_kwargs={"target_accept": 0.5}, + metropolis={"scaling": 0}, + random_seed=1418 * 2, + ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) From a539290ac2984f2f0ba33f14ed9e03932f7c1d70 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 8 Mar 2023 11:48:33 -0600 Subject: [PATCH 06/14] Fix test_step_args failure --- tests/sampling/test_mcmc.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index ea220ec7cb..06ee550fa6 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -646,6 +646,12 @@ def test_step_args(): idata0 = pm.sample(target_accept=0.5, random_seed=1410) idata1 = pm.sample(nuts_kwargs={"target_accept": 0.5}, random_seed=1410 * 2) idata2 = pm.sample(target_accept=0.5, nuts_kwargs={"max_treedepth": 10}, random_seed=1410) + idata3 = pm.sample( + nuts_sampler="numpyro", + target_accept=0.5, + nuts_kwargs={"max_treedepth": 10}, + random_seed=1410, + ) with pytest.raises(ValueError, match="`target_accept` was defined twice."): pm.sample(target_accept=0.5, nuts_kwargs={"target_accept": 0.95}, random_seed=1410) @@ -653,6 +659,7 @@ def test_step_args(): npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata2.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) + npt.assert_almost_equal(idata3.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) with pm.Model() as model: a = pm.Normal("a") @@ -665,12 +672,6 @@ def test_step_args(): idata1 = pm.sample( nuts_kwargs={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 ) - idata2 = pm.sample( - nuts_sampler="numpyro", - nuts_kwargs={"target_accept": 0.5}, - metropolis={"scaling": 0}, - random_seed=1418 * 2, - ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) From 2dc3537ac87a2070c569a7512a51398acc0f4e17 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 9 Mar 2023 12:21:15 -0600 Subject: [PATCH 07/14] Moved numpyro test; use nuts_sampler_kwargs --- pymc/sampling/jax.py | 2 ++ pymc/sampling/mcmc.py | 10 ++-------- tests/sampling/test_mcmc.py | 7 ------- tests/sampling/test_mcmc_external.py | 14 ++++++++++++++ 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 0cdc8afd0c..10cacb9c2e 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -317,6 +317,7 @@ def sample_blackjax_nuts( postprocessing_backend: Optional[str] = None, postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. @@ -529,6 +530,7 @@ def sample_numpyro_nuts( postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None, + **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``numpyro`` library. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 50d55a0f15..236dde0b87 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -292,13 +292,10 @@ def _sample_external_nuts( initvals=initvals, model=model, progressbar=progressbar, - keep_untransformed=nuts_sampler_kwargs.get("keep_untransformed", False), - chain_method=nuts_sampler_kwargs.get("chain_method", "parallel"), - postprocessing_backend=nuts_sampler_kwargs.get("postprocessing_backend"), - postprocessing_chunks=nuts_sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, **kwargs, + **nuts_sampler_kwargs, ) return idata @@ -313,11 +310,8 @@ def _sample_external_nuts( random_seed=random_seed, initvals=initvals, model=model, - keep_untransformed=nuts_sampler_kwargs.get("keep_untransformed", False), - chain_method=nuts_sampler_kwargs.get("chain_method", "parallel"), - postprocessing_backend=nuts_sampler_kwargs.get("postprocessing_backend"), - postprocessing_chunks=nuts_sampler_kwargs.get("postprocessing_chunks"), idata_kwargs=idata_kwargs, + **nuts_sampler_kwargs, ) return idata diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 06ee550fa6..31f3338617 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -646,12 +646,6 @@ def test_step_args(): idata0 = pm.sample(target_accept=0.5, random_seed=1410) idata1 = pm.sample(nuts_kwargs={"target_accept": 0.5}, random_seed=1410 * 2) idata2 = pm.sample(target_accept=0.5, nuts_kwargs={"max_treedepth": 10}, random_seed=1410) - idata3 = pm.sample( - nuts_sampler="numpyro", - target_accept=0.5, - nuts_kwargs={"max_treedepth": 10}, - random_seed=1410, - ) with pytest.raises(ValueError, match="`target_accept` was defined twice."): pm.sample(target_accept=0.5, nuts_kwargs={"target_accept": 0.95}, random_seed=1410) @@ -659,7 +653,6 @@ def test_step_args(): npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata2.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) - npt.assert_almost_equal(idata3.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) with pm.Model() as model: a = pm.Normal("a") diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 2439738955..0d0037f167 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import numpy.testing as npt import pytest from pymc import Model, Normal, sample @@ -63,3 +64,16 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): assert idata1.posterior.chain.size == 2 assert idata1.posterior.draw.size == 500 np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) + + +def test_step_args(): + with Model() as model: + a = Normal("a") + idata = sample( + nuts_sampler="numpyro", + target_accept=0.5, + nuts_kwargs={"max_treedepth": 10}, + random_seed=1410, + ) + + npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) From 3098120cd1d918f1a0ce903442c4eea3e33276bd Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 9 Mar 2023 12:58:51 -0600 Subject: [PATCH 08/14] Test failure fix --- pymc/sampling/mcmc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 236dde0b87..3a10211c79 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -577,7 +577,6 @@ def sample( draws=draws, tune=tune, chains=chains, - target_accept=nuts_kwargs.get("target_accept", 0.8), random_seed=random_seed, initvals=initvals, model=model, @@ -586,13 +585,14 @@ def sample( nuts_kwargs=nuts_kwargs, nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, + **nuts_kwargs, ) if isinstance(step, list): step = CompoundStep(step) elif isinstance(step, NUTS) and auto_nuts_init: - for k, v in nuts_kwargs.items(): - kwargs.setdefault(k, v) + # for k, v in nuts_kwargs.items(): + # kwargs.setdefault(k, v) _log.info("Auto-assigning NUTS sampler...") initial_points, step = init_nuts( init=init, @@ -604,7 +604,8 @@ def sample( jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, - **kwargs, + # **kwargs, + **nuts_kwargs, ) if initial_points is None: From 4964c3b86366c403fcd1b5d45b80d18abd494838 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 10 Mar 2023 22:08:32 -0500 Subject: [PATCH 09/14] Fix docstrings for nuts/sampler kwargs; debugging target accept argument passing --- pymc/sampling/mcmc.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 3a10211c79..f96750d788 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -420,9 +420,9 @@ def sample( idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` nuts_kwargs : dict, optional - Keyword arguments for the NUTS sampler. + Keyword arguments for the NUTS step method. nuts_sampler_kwargs : dict, optional - Keyword arguments for the sampler. + Keyword arguments for the sampling library that implements nuts. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. @@ -591,8 +591,8 @@ def sample( if isinstance(step, list): step = CompoundStep(step) elif isinstance(step, NUTS) and auto_nuts_init: - # for k, v in nuts_kwargs.items(): - # kwargs.setdefault(k, v) + for k, v in nuts_kwargs.items(): + kwargs.setdefault(k, v) _log.info("Auto-assigning NUTS sampler...") initial_points, step = init_nuts( init=init, @@ -604,8 +604,7 @@ def sample( jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, - # **kwargs, - **nuts_kwargs, + **kwargs, ) if initial_points is None: From 0f65dab5efd878acdd6e64fe6d7834760b10cf9f Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 10 Mar 2023 22:17:13 -0500 Subject: [PATCH 10/14] Removed passing kwargs to external samplers; only nuts_sampler_kwargs for clarity --- pymc/sampling/mcmc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f96750d788..611d562a91 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -276,7 +276,7 @@ def _sample_external_nuts( target_accept=target_accept, seed=_get_seeds_per_chain(random_seed, 1)[0], progress_bar=progressbar, - **kwargs, + **nuts_sampler_kwargs, ) return idata @@ -294,7 +294,6 @@ def _sample_external_nuts( progressbar=progressbar, idata_kwargs=idata_kwargs, nuts_kwargs=nuts_kwargs, - **kwargs, **nuts_sampler_kwargs, ) return idata From bc8e5641087382c9768378d615bc7604dda83e43 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 11 Mar 2023 20:40:31 -0500 Subject: [PATCH 11/14] Revert nuts_kwargs to nuts --- pymc/sampling/mcmc.py | 25 ++++++++++--------------- tests/sampling/test_mcmc.py | 8 ++++---- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 611d562a91..dd29b5c0b8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -237,7 +237,6 @@ def _sample_external_nuts( model: Model, progressbar: bool, idata_kwargs: Optional[Dict], - nuts_kwargs: Optional[Dict], nuts_sampler_kwargs: Optional[Dict], **kwargs, ): @@ -293,7 +292,6 @@ def _sample_external_nuts( model=model, progressbar=progressbar, idata_kwargs=idata_kwargs, - nuts_kwargs=nuts_kwargs, **nuts_sampler_kwargs, ) return idata @@ -340,7 +338,6 @@ def sample( keep_warning_stat: bool = False, return_inferencedata: bool = True, idata_kwargs: Optional[Dict[str, Any]] = None, - nuts_kwargs: Optional[Dict[str, Any]] = None, nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, callback=None, mp_ctx=None, @@ -418,8 +415,6 @@ def sample( `MultiTrace` (False). Defaults to `True`. idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` - nuts_kwargs : dict, optional - Keyword arguments for the NUTS step method. nuts_sampler_kwargs : dict, optional Keyword arguments for the sampling library that implements nuts. callback : function, default=None @@ -446,7 +441,7 @@ def sample( For example: - 1. ``target_accept`` to NUTS: nuts_kwargs={'target_accept':0.9} + 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9} 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7} Note that available step names are: @@ -507,14 +502,15 @@ def sample( initvals = kwargs.pop("start") if nuts_sampler_kwargs is None: nuts_sampler_kwargs = {} - if nuts_kwargs is None: - nuts_kwargs = {} if "target_accept" in kwargs: - if "target_accept" in nuts_kwargs: + if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: raise ValueError( - "`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts_kwargs` dict." + "`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts` kwarg." ) - nuts_kwargs["target_accept"] = kwargs.pop("target_accept") + if "nuts" in kwargs: + kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept") + else: + kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")} if isinstance(trace, list): raise DeprecationWarning( "We have removed support for partial traces because it simplified things." @@ -581,17 +577,16 @@ def sample( model=model, progressbar=progressbar, idata_kwargs=idata_kwargs, - nuts_kwargs=nuts_kwargs, nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, - **nuts_kwargs, ) if isinstance(step, list): step = CompoundStep(step) elif isinstance(step, NUTS) and auto_nuts_init: - for k, v in nuts_kwargs.items(): - kwargs.setdefault(k, v) + if "nuts" in kwargs: + nuts_kwargs = kwargs.pop("nuts") + [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] _log.info("Auto-assigning NUTS sampler...") initial_points, step = init_nuts( init=init, diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 31f3338617..dbdc0d2a40 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -644,11 +644,11 @@ def test_step_args(): with pm.Model() as model: a = pm.Normal("a") idata0 = pm.sample(target_accept=0.5, random_seed=1410) - idata1 = pm.sample(nuts_kwargs={"target_accept": 0.5}, random_seed=1410 * 2) - idata2 = pm.sample(target_accept=0.5, nuts_kwargs={"max_treedepth": 10}, random_seed=1410) + idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2) + idata2 = pm.sample(target_accept=0.5, nuts={"max_treedepth": 10}, random_seed=1410) with pytest.raises(ValueError, match="`target_accept` was defined twice."): - pm.sample(target_accept=0.5, nuts_kwargs={"target_accept": 0.95}, random_seed=1410) + pm.sample(target_accept=0.5, nuts={"target_accept": 0.95}, random_seed=1410) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) @@ -663,7 +663,7 @@ def test_step_args(): "ignore", "invalid value encountered in double_scalars", RuntimeWarning ) idata1 = pm.sample( - nuts_kwargs={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 + nuts={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) From 64ac8559aca33d86b5984a949dd8482d58461126 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 11 Mar 2023 21:55:24 -0500 Subject: [PATCH 12/14] Fix failures in test_mcmc_external --- pymc/sampling/mcmc.py | 5 +++++ tests/sampling/test_mcmc_external.py | 5 +---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index dd29b5c0b8..d485924d79 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -567,11 +567,16 @@ def sample( raise ValueError( "Model can not be sampled with NUTS alone. Your model is probably not continuous." ) + if "nuts" in kwargs: + target_accept = kwargs["nuts"].get("target_accept", 0.8) + else: + target_accept = 0.8 return _sample_external_nuts( sampler=nuts_sampler, draws=draws, tune=tune, chains=chains, + target_accept=target_accept, random_seed=random_seed, initvals=initvals, model=model, diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 0d0037f167..3c86154c14 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -18,9 +18,6 @@ from pymc import Model, Normal, sample -# turns all warnings into errors for this module -pytestmark = pytest.mark.filterwarnings("error") - @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) def test_external_nuts_sampler(recwarn, nuts_sampler): @@ -72,7 +69,7 @@ def test_step_args(): idata = sample( nuts_sampler="numpyro", target_accept=0.5, - nuts_kwargs={"max_treedepth": 10}, + nuts={"max_treedepth": 10}, random_seed=1410, ) From 29f41ceb110465fe903a483f8b3ec8cd7ef7f133 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 11 Mar 2023 22:08:01 -0500 Subject: [PATCH 13/14] Revert argument parameterization --- pymc/sampling/mcmc.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index d485924d79..10e3f62f31 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -567,16 +567,12 @@ def sample( raise ValueError( "Model can not be sampled with NUTS alone. Your model is probably not continuous." ) - if "nuts" in kwargs: - target_accept = kwargs["nuts"].get("target_accept", 0.8) - else: - target_accept = 0.8 return _sample_external_nuts( sampler=nuts_sampler, draws=draws, tune=tune, chains=chains, - target_accept=target_accept, + target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), random_seed=random_seed, initvals=initvals, model=model, From 1e08e559efa729bd5a129021f570ddd15b4a497a Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sun, 12 Mar 2023 14:31:59 -0400 Subject: [PATCH 14/14] Update pymc/sampling/mcmc.py Apply suggested nuts_sampler_kwarg docstring change Co-authored-by: Thomas Wiecki --- pymc/sampling/mcmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 10e3f62f31..1dd22fa818 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -417,6 +417,7 @@ def sample( Keyword arguments for :func:`pymc.to_inference_data` nuts_sampler_kwargs : dict, optional Keyword arguments for the sampling library that implements nuts. + Only used when an external sampler is specified via the `nuts_sampler` kwarg. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace.