39
39
import numpy as np
40
40
import pytensor .gradient as tg
41
41
42
- from arviz import InferenceData
42
+ from arviz import InferenceData , dict_to_dataset
43
+ from arviz .data .base import make_attrs
43
44
from fastprogress .fastprogress import progress_bar
44
45
from pytensor .graph .basic import Variable
45
46
from typing_extensions import Protocol , TypeAlias
46
47
47
48
import pymc as pm
48
49
49
50
from pymc .backends import RunType , TraceOrBackend , init_traces
51
+ from pymc .backends .arviz import (
52
+ coords_and_dims_for_inferencedata ,
53
+ find_constants ,
54
+ find_observations ,
55
+ )
50
56
from pymc .backends .base import IBaseTrace , MultiTrace , _choose_chains
51
57
from pymc .blocking import DictToArrayBijection
52
58
from pymc .exceptions import SamplingError
@@ -293,8 +299,8 @@ def _sample_external_nuts(
293
299
"`idata_kwargs` are currently ignored by the nutpie sampler" ,
294
300
UserWarning ,
295
301
)
296
-
297
302
compiled_model = nutpie .compile_pymc_model (model )
303
+ t_start = time .time ()
298
304
idata = nutpie .sample (
299
305
compiled_model ,
300
306
draws = draws ,
@@ -305,6 +311,37 @@ def _sample_external_nuts(
305
311
progress_bar = progressbar ,
306
312
** nuts_sampler_kwargs ,
307
313
)
314
+ t_sample = time .time () - t_start
315
+ # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
316
+ # gather observed and constant data as nutpie.sample() has no access to the PyMC model
317
+ coords , dims = coords_and_dims_for_inferencedata (model )
318
+ constant_data = dict_to_dataset (
319
+ find_constants (model ),
320
+ library = pm ,
321
+ coords = coords ,
322
+ dims = dims ,
323
+ default_dims = [],
324
+ )
325
+ observed_data = dict_to_dataset (
326
+ find_observations (model ),
327
+ library = pm ,
328
+ coords = coords ,
329
+ dims = dims ,
330
+ default_dims = [],
331
+ )
332
+ attrs = make_attrs (
333
+ {
334
+ "sampling_time" : t_sample ,
335
+ },
336
+ library = nutpie ,
337
+ )
338
+ for k , v in attrs .items ():
339
+ idata .posterior .attrs [k ] = v
340
+ idata .add_groups (
341
+ {"constant_data" : constant_data , "observed_data" : observed_data },
342
+ coords = coords ,
343
+ dims = dims ,
344
+ )
308
345
return idata
309
346
310
347
elif sampler == "numpyro" :
0 commit comments