Skip to content

Commit 9227827

Browse files
jaharvey8 authored and ricardoV94 committed
Add helper to convert model coords and dims into format accepted by InferenceData
1 parent f77372c commit 9227827

File tree

5 files changed

+47
-50
lines changed

5 files changed

+47
-50
lines changed

pymc/backends/arviz.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,18 @@ def is_data(name, var, model) -> bool:
100100
return constant_data
101101

102102

103+
def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
104+
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
105+
coords = {
106+
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
107+
for cname, cvals in model.coords.items()
108+
if cvals is not None
109+
}
110+
dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
111+
112+
return coords, dims
113+
114+
103115
class _DefaultTrace:
104116
"""
105117
Utility for collecting samples into a dictionary.
@@ -216,19 +228,11 @@ def __init__(
216228
" one of trace, prior, posterior_predictive or predictions."
217229
)
218230

219-
# Make coord types more rigid
220-
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
221-
if coords:
222-
untyped_coords.update(coords)
223-
self.coords = {
224-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
225-
for cname, cvals in untyped_coords.items()
226-
if cvals is not None
227-
}
228-
229-
self.dims = {} if dims is None else dims
230-
model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}
231-
self.dims = {**model_dims, **self.dims}
231+
user_coords = {} if coords is None else coords
232+
user_dims = {} if dims is None else dims
233+
model_coords, model_dims = coords_and_dims_for_inferencedata(self.model)
234+
self.coords = {**model_coords, **user_coords}
235+
self.dims = {**model_dims, **user_dims}
232236

233237
if sample_dims is None:
234238
sample_dims = ["chain", "draw"]

pymc/model/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
942942
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
943943
"""
944944
warnings.warn(
945-
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
945+
"Model.RV_dims is deprecated. Use Model.named_vars_to_dims instead.",
946946
FutureWarning,
947947
)
948948
return self.named_vars_to_dims

pymc/sampling/jax.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737
from pytensor.tensor.random.type import RandomType
3838

3939
from pymc import Model, modelcontext
40-
from pymc.backends.arviz import find_constants, find_observations
40+
from pymc.backends.arviz import (
41+
coords_and_dims_for_inferencedata,
42+
find_constants,
43+
find_observations,
44+
)
4145
from pymc.distributions.multivariate import PosDefMatrix
4246
from pymc.initial_point import StartDict
4347
from pymc.logprob.utils import CheckParameterValue
@@ -392,17 +396,6 @@ def sample_blackjax_nuts(
392396

393397
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
394398

395-
coords = {
396-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
397-
for cname, cvals in model.coords.items()
398-
if cvals is not None
399-
}
400-
401-
dims = {
402-
var_name: [dim for dim in dims if dim is not None]
403-
for var_name, dims in model.named_vars_to_dims.items()
404-
}
405-
406399
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
407400

408401
tic1 = datetime.now()
@@ -485,7 +478,7 @@ def sample_blackjax_nuts(
485478
"sampling_time": (tic3 - tic2).total_seconds(),
486479
}
487480

488-
posterior = mcmc_samples
481+
coords, dims = coords_and_dims_for_inferencedata(model)
489482
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
490483
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
491484
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
@@ -500,7 +493,7 @@ def sample_blackjax_nuts(
500493
dims=dims,
501494
attrs=make_attrs(attrs, library=blackjax),
502495
)
503-
az_trace = to_trace(posterior=posterior, **idata_kwargs)
496+
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
504497

505498
return az_trace
506499

@@ -613,17 +606,6 @@ def sample_numpyro_nuts(
613606

614607
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
615608

616-
coords = {
617-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
618-
for cname, cvals in model.coords.items()
619-
if cvals is not None
620-
}
621-
622-
dims = {
623-
var_name: [dim for dim in dims if dim is not None]
624-
for var_name, dims in model.named_vars_to_dims.items()
625-
}
626-
627609
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
628610

629611
tic1 = datetime.now()
@@ -715,7 +697,7 @@ def sample_numpyro_nuts(
715697
"sampling_time": (tic3 - tic2).total_seconds(),
716698
}
717699

718-
posterior = mcmc_samples
700+
coords, dims = coords_and_dims_for_inferencedata(model)
719701
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
720702
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
721703
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
@@ -730,5 +712,5 @@ def sample_numpyro_nuts(
730712
dims=dims,
731713
attrs=make_attrs(attrs, library=numpyro),
732714
)
733-
az_trace = to_trace(posterior=posterior, **idata_kwargs)
715+
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
734716
return az_trace

pymc/stats/log_likelihood.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
# limitations under the License.
1414
from typing import Optional, Sequence, cast
1515

16-
import numpy as np
17-
1816
from arviz import InferenceData, dict_to_dataset
1917
from fastprogress import progress_bar
2018

2119
import pymc
2220

23-
from pymc.backends.arviz import _DefaultTrace
21+
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
2422
from pymc.model import Model, modelcontext
2523
from pymc.pytensorf import PointFunc
2624
from pymc.util import dataset_to_point_list
@@ -113,14 +111,12 @@ def compute_log_likelihood(
113111
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
114112
)
115113

114+
coords, dims = coords_and_dims_for_inferencedata(model)
116115
loglike_dataset = dict_to_dataset(
117116
loglike_trace,
118117
library=pymc,
119-
dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
120-
coords={
121-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
122-
for cname, cvals in model.coords.items()
123-
},
118+
dims=dims,
119+
coords=coords,
124120
default_dims=list(sample_dims),
125121
skip_event_dims=True,
126122
)

tests/stats/test_log_likelihood.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
import scipy.stats as st
1717

18-
from arviz import InferenceData, dict_to_dataset
18+
from arviz import InferenceData, dict_to_dataset, from_dict
1919

2020
from pymc.distributions import Dirichlet, Normal
2121
from pymc.distributions.transforms import log
@@ -117,3 +117,18 @@ def test_invalid_var_names(self):
117117
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
118118
with pytest.raises(ValueError, match="var_names must refer to observed_RVs"):
119119
compute_log_likelihood(idata, var_names=["x"])
120+
121+
def test_dims_without_coords(self):
122+
# Issues #6820
123+
with Model() as m:
124+
x = Normal("x")
125+
y = Normal("y", x, observed=[0, 0, 0], shape=(3,), dims="obs")
126+
127+
trace = from_dict({"x": [[0, 1]]})
128+
llike = compute_log_likelihood(trace)
129+
130+
assert len(llike.log_likelihood["obs"]) == 3
131+
np.testing.assert_allclose(
132+
llike.log_likelihood["y"].values,
133+
st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
134+
)

0 commit comments

Comments
 (0)