Skip to content

Commit cb64480

Browse files
authored
Add constant and observed data to nutpie idata (#6943)
* Add constant and observed data to nutpie idata * change order and add comment to revert
1 parent 6f4a040 commit cb64480

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

pymc/sampling/mcmc.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,20 @@
3939
import numpy as np
4040
import pytensor.gradient as tg
4141

42-
from arviz import InferenceData
42+
from arviz import InferenceData, dict_to_dataset
43+
from arviz.data.base import make_attrs
4344
from fastprogress.fastprogress import progress_bar
4445
from pytensor.graph.basic import Variable
4546
from typing_extensions import Protocol, TypeAlias
4647

4748
import pymc as pm
4849

4950
from pymc.backends import RunType, TraceOrBackend, init_traces
51+
from pymc.backends.arviz import (
52+
coords_and_dims_for_inferencedata,
53+
find_constants,
54+
find_observations,
55+
)
5056
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
5157
from pymc.blocking import DictToArrayBijection
5258
from pymc.exceptions import SamplingError
@@ -293,8 +299,8 @@ def _sample_external_nuts(
293299
"`idata_kwargs` are currently ignored by the nutpie sampler",
294300
UserWarning,
295301
)
296-
297302
compiled_model = nutpie.compile_pymc_model(model)
303+
t_start = time.time()
298304
idata = nutpie.sample(
299305
compiled_model,
300306
draws=draws,
@@ -305,6 +311,37 @@ def _sample_external_nuts(
305311
progress_bar=progressbar,
306312
**nuts_sampler_kwargs,
307313
)
314+
t_sample = time.time() - t_start
315+
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
316+
# gather observed and constant data as nutpie.sample() has no access to the PyMC model
317+
coords, dims = coords_and_dims_for_inferencedata(model)
318+
constant_data = dict_to_dataset(
319+
find_constants(model),
320+
library=pm,
321+
coords=coords,
322+
dims=dims,
323+
default_dims=[],
324+
)
325+
observed_data = dict_to_dataset(
326+
find_observations(model),
327+
library=pm,
328+
coords=coords,
329+
dims=dims,
330+
default_dims=[],
331+
)
332+
attrs = make_attrs(
333+
{
334+
"sampling_time": t_sample,
335+
},
336+
library=nutpie,
337+
)
338+
for k, v in attrs.items():
339+
idata.posterior.attrs[k] = v
340+
idata.add_groups(
341+
{"constant_data": constant_data, "observed_data": observed_data},
342+
coords=coords,
343+
dims=dims,
344+
)
308345
return idata
309346

310347
elif sampler == "numpyro":

tests/sampling/test_mcmc_external.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy.testing as npt
1717
import pytest
1818

19-
from pymc import Model, Normal, sample
19+
from pymc import ConstantData, Model, Normal, sample
2020

2121

2222
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@@ -25,7 +25,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
2525
pytest.importorskip(nuts_sampler)
2626

2727
with Model():
28-
Normal("x")
28+
x = Normal("x", 100, 5)
29+
y = ConstantData("y", [1, 2, 3, 4])
30+
ConstantData("z", [100, 190, 310, 405])
31+
32+
Normal("L", mu=x, sigma=0.1, observed=y)
2933

3034
kwargs = dict(
3135
nuts_sampler=nuts_sampler,
@@ -55,7 +59,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
5559
)
5660
)
5761
assert warns == expected
58-
62+
assert "y" in idata1.constant_data
63+
assert "z" in idata1.constant_data
64+
assert "L" in idata1.observed_data
5965
assert idata1.posterior.chain.size == 2
6066
assert idata1.posterior.draw.size == 500
6167
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)

0 commit comments

Comments
 (0)