@@ -237,10 +237,14 @@ def _sample_external_nuts(
237
237
model : Model ,
238
238
progressbar : bool ,
239
239
idata_kwargs : Optional [Dict ],
240
+ nuts_sampler_kwargs : Optional [Dict ],
240
241
** kwargs ,
241
242
):
242
243
warnings .warn ("Use of external NUTS sampler is still experimental" , UserWarning )
243
244
245
+ if nuts_sampler_kwargs is None :
246
+ nuts_sampler_kwargs = {}
247
+
244
248
if sampler == "nutpie" :
245
249
try :
246
250
import nutpie
@@ -271,7 +275,7 @@ def _sample_external_nuts(
271
275
target_accept = target_accept ,
272
276
seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
273
277
progress_bar = progressbar ,
274
- ** kwargs ,
278
+ ** nuts_sampler_kwargs ,
275
279
)
276
280
return idata
277
281
@@ -288,7 +292,7 @@ def _sample_external_nuts(
288
292
model = model ,
289
293
progressbar = progressbar ,
290
294
idata_kwargs = idata_kwargs ,
291
- ** kwargs ,
295
+ ** nuts_sampler_kwargs ,
292
296
)
293
297
return idata
294
298
@@ -304,7 +308,7 @@ def _sample_external_nuts(
304
308
initvals = initvals ,
305
309
model = model ,
306
310
idata_kwargs = idata_kwargs ,
307
- ** kwargs ,
311
+ ** nuts_sampler_kwargs ,
308
312
)
309
313
return idata
310
314
@@ -334,6 +338,7 @@ def sample(
334
338
keep_warning_stat : bool = False ,
335
339
return_inferencedata : bool = True ,
336
340
idata_kwargs : Optional [Dict [str , Any ]] = None ,
341
+ nuts_sampler_kwargs : Optional [Dict [str , Any ]] = None ,
337
342
callback = None ,
338
343
mp_ctx = None ,
339
344
model : Optional [Model ] = None ,
@@ -410,6 +415,9 @@ def sample(
410
415
`MultiTrace` (False). Defaults to `True`.
411
416
idata_kwargs : dict, optional
412
417
Keyword arguments for :func:`pymc.to_inference_data`
418
+ nuts_sampler_kwargs : dict, optional
419
+ Keyword arguments for the sampling library that implements nuts.
420
+ Only used when an external sampler is specified via the `nuts_sampler` kwarg.
413
421
callback : function, default=None
414
422
A function which gets called for every sample from the trace of a chain. The function is
415
423
called with the trace and the current draw and will contain all samples for a single trace.
@@ -493,6 +501,8 @@ def sample(
493
501
stacklevel = 2 ,
494
502
)
495
503
initvals = kwargs .pop ("start" )
504
+ if nuts_sampler_kwargs is None :
505
+ nuts_sampler_kwargs = {}
496
506
if "target_accept" in kwargs :
497
507
if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
498
508
raise ValueError (
@@ -569,6 +579,7 @@ def sample(
569
579
model = model ,
570
580
progressbar = progressbar ,
571
581
idata_kwargs = idata_kwargs ,
582
+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
572
583
** kwargs ,
573
584
)
574
585
0 commit comments