diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 0cdc8afd0c..10cacb9c2e 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -317,6 +317,7 @@ def sample_blackjax_nuts( postprocessing_backend: Optional[str] = None, postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. @@ -529,6 +530,7 @@ def sample_numpyro_nuts( postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None, + **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``numpyro`` library. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 0449b029f8..1dd22fa818 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -237,10 +237,14 @@ def _sample_external_nuts( model: Model, progressbar: bool, idata_kwargs: Optional[Dict], + nuts_sampler_kwargs: Optional[Dict], **kwargs, ): warnings.warn("Use of external NUTS sampler is still experimental", UserWarning) + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} + if sampler == "nutpie": try: import nutpie @@ -271,7 +275,7 @@ def _sample_external_nuts( target_accept=target_accept, seed=_get_seeds_per_chain(random_seed, 1)[0], progress_bar=progressbar, - **kwargs, + **nuts_sampler_kwargs, ) return idata @@ -288,7 +292,7 @@ def _sample_external_nuts( model=model, progressbar=progressbar, idata_kwargs=idata_kwargs, - **kwargs, + **nuts_sampler_kwargs, ) return idata @@ -304,7 +308,7 @@ def _sample_external_nuts( initvals=initvals, model=model, idata_kwargs=idata_kwargs, - **kwargs, + **nuts_sampler_kwargs, ) return idata @@ -334,6 +338,7 @@ def sample( keep_warning_stat: bool = False, return_inferencedata: bool = True, idata_kwargs: Optional[Dict[str, Any]] = None, + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, callback=None, mp_ctx=None, model: Optional[Model] = None, @@ -410,6 +415,9 @@ def sample( `MultiTrace` (False). Defaults to `True`. idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` + nuts_sampler_kwargs : dict, optional + Keyword arguments for the sampling library that implements nuts. + Only used when an external sampler is specified via the `nuts_sampler` kwarg. callback : function, default=None A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. @@ -493,6 +501,8 @@ def sample( stacklevel=2, ) initvals = kwargs.pop("start") + if nuts_sampler_kwargs is None: + nuts_sampler_kwargs = {} if "target_accept" in kwargs: if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: raise ValueError( @@ -569,6 +579,7 @@ def sample( model=model, progressbar=progressbar, idata_kwargs=idata_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, ) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 2439738955..3c86154c14 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -13,13 +13,11 @@ # limitations under the License. import numpy as np +import numpy.testing as npt import pytest from pymc import Model, Normal, sample -# turns all warnings into errors for this module -pytestmark = pytest.mark.filterwarnings("error") - @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) def test_external_nuts_sampler(recwarn, nuts_sampler): @@ -63,3 +61,16 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): assert idata1.posterior.chain.size == 2 assert idata1.posterior.draw.size == 500 np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) + + +def test_step_args(): + with Model() as model: + a = Normal("a") + idata = sample( + nuts_sampler="numpyro", + target_accept=0.5, + nuts={"max_treedepth": 10}, + random_seed=1410, + ) + + npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)