Skip to content

Commit da68d11

Browse files
fonnesbecktwiecki
andauthored
Add nuts_sampler_kwargs and nuts_kwargs to pm.sample (#6581)
* Reinstate nuts_kwargs in sample for passing arguments to numpyro * Removed redundant condition from if statement * Added sampler_kwargs * Rename sampler_kwargs; fix test failures * Ensure nuts_kwargs get passed to pymc nuts * Fix test_step_args failure * Moved numpyro test; use nuts_sampler_kwargs * Test failure fix * Fix docstrings for nuts/sampler kwargs; debugging target accept argument passing * Removed passing kwargs to external samplers; only nuts_sampler_kwargs for clarity * Revert nuts_kwargs to nuts * Fix failures in test_mcmc_external * Revert argument parameterization * Update pymc/sampling/mcmc.py Apply suggested nuts_sampler_kwarg docstring change Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com> --------- Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
1 parent ff8a4c7 commit da68d11

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

pymc/sampling/jax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def sample_blackjax_nuts(
317317
postprocessing_backend: Optional[str] = None,
318318
postprocessing_chunks: Optional[int] = None,
319319
idata_kwargs: Optional[Dict[str, Any]] = None,
320+
**kwargs,
320321
) -> az.InferenceData:
321322
"""
322323
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -529,6 +530,7 @@ def sample_numpyro_nuts(
529530
postprocessing_chunks: Optional[int] = None,
530531
idata_kwargs: Optional[Dict] = None,
531532
nuts_kwargs: Optional[Dict] = None,
533+
**kwargs,
532534
) -> az.InferenceData:
533535
"""
534536
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

pymc/sampling/mcmc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,14 @@ def _sample_external_nuts(
237237
model: Model,
238238
progressbar: bool,
239239
idata_kwargs: Optional[Dict],
240+
nuts_sampler_kwargs: Optional[Dict],
240241
**kwargs,
241242
):
242243
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
243244

245+
if nuts_sampler_kwargs is None:
246+
nuts_sampler_kwargs = {}
247+
244248
if sampler == "nutpie":
245249
try:
246250
import nutpie
@@ -271,7 +275,7 @@ def _sample_external_nuts(
271275
target_accept=target_accept,
272276
seed=_get_seeds_per_chain(random_seed, 1)[0],
273277
progress_bar=progressbar,
274-
**kwargs,
278+
**nuts_sampler_kwargs,
275279
)
276280
return idata
277281

@@ -288,7 +292,7 @@ def _sample_external_nuts(
288292
model=model,
289293
progressbar=progressbar,
290294
idata_kwargs=idata_kwargs,
291-
**kwargs,
295+
**nuts_sampler_kwargs,
292296
)
293297
return idata
294298

@@ -304,7 +308,7 @@ def _sample_external_nuts(
304308
initvals=initvals,
305309
model=model,
306310
idata_kwargs=idata_kwargs,
307-
**kwargs,
311+
**nuts_sampler_kwargs,
308312
)
309313
return idata
310314

@@ -334,6 +338,7 @@ def sample(
334338
keep_warning_stat: bool = False,
335339
return_inferencedata: bool = True,
336340
idata_kwargs: Optional[Dict[str, Any]] = None,
341+
nuts_sampler_kwargs: Optional[Dict[str, Any]] = None,
337342
callback=None,
338343
mp_ctx=None,
339344
model: Optional[Model] = None,
@@ -410,6 +415,9 @@ def sample(
410415
`MultiTrace` (False). Defaults to `True`.
411416
idata_kwargs : dict, optional
412417
Keyword arguments for :func:`pymc.to_inference_data`
418+
nuts_sampler_kwargs : dict, optional
419+
Keyword arguments for the sampling library that implements nuts.
420+
Only used when an external sampler is specified via the `nuts_sampler` kwarg.
413421
callback : function, default=None
414422
A function which gets called for every sample from the trace of a chain. The function is
415423
called with the trace and the current draw and will contain all samples for a single trace.
@@ -493,6 +501,8 @@ def sample(
493501
stacklevel=2,
494502
)
495503
initvals = kwargs.pop("start")
504+
if nuts_sampler_kwargs is None:
505+
nuts_sampler_kwargs = {}
496506
if "target_accept" in kwargs:
497507
if "nuts" in kwargs and "target_accept" in kwargs["nuts"]:
498508
raise ValueError(
@@ -569,6 +579,7 @@ def sample(
569579
model=model,
570580
progressbar=progressbar,
571581
idata_kwargs=idata_kwargs,
582+
nuts_sampler_kwargs=nuts_sampler_kwargs,
572583
**kwargs,
573584
)
574585

tests/sampling/test_mcmc_external.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import numpy.testing as npt
1617
import pytest
1718

1819
from pymc import Model, Normal, sample
1920

20-
# turns all warnings into errors for this module
21-
pytestmark = pytest.mark.filterwarnings("error")
22-
2321

2422
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
2523
def test_external_nuts_sampler(recwarn, nuts_sampler):
@@ -63,3 +61,16 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
6361
assert idata1.posterior.chain.size == 2
6462
assert idata1.posterior.draw.size == 500
6563
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
64+
65+
66+
def test_step_args():
67+
with Model() as model:
68+
a = Normal("a")
69+
idata = sample(
70+
nuts_sampler="numpyro",
71+
target_accept=0.5,
72+
nuts={"max_treedepth": 10},
73+
random_seed=1410,
74+
)
75+
76+
npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)

0 commit comments

Comments
 (0)