Skip to content

Commit 74748c7

Browse files
authored
Fix compute_log_prior in models with Deterministics (#7168)
1 parent 97a7a00 commit 74748c7

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

pymc/stats/log_density.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def compute_log_density(
131131
target_rvs = model.observed_RVs
132132
target_str = "observed_RVs"
133133
else:
134-
target_rvs = model.unobserved_RVs
134+
target_rvs = model.free_RVs
135135
target_str = "free_RVs"
136136

137137
if var_names is None:

tests/stats/test_log_density.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from pymc.distributions import Dirichlet, Normal
2121
from pymc.distributions.transforms import log
22-
from pymc.model import Model
22+
from pymc.model import Deterministic, Model
2323
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior
2424
from tests.distributions.test_multivariate import dirichlet_logpdf
2525

@@ -41,7 +41,7 @@ def test_basic(self, transform):
4141
assert m.rvs_to_transforms[x] is transform
4242

4343
assert res is idata
44-
assert res.log_likelihood.dims == {"chain": 4, "draw": 25, "test_dim": 3}
44+
assert res.log_likelihood.sizes == {"chain": 4, "draw": 25, "test_dim": 3}
4545

4646
np.testing.assert_allclose(
4747
res.log_likelihood["y"].values,
@@ -62,7 +62,7 @@ def test_multivariate(self):
6262
idata = InferenceData(posterior=dict_to_dataset({"p": p_draws}))
6363
res = compute_log_likelihood(idata)
6464

65-
assert res.log_likelihood.dims == {"chain": 4, "draw": 25, "test_event_dim": 10}
65+
assert res.log_likelihood.sizes == {"chain": 4, "draw": 25, "test_event_dim": 10}
6666

6767
np.testing.assert_allclose(
6868
res.log_likelihood["y"].values,
@@ -149,7 +149,26 @@ def test_basic_log_prior(self, transform):
149149
assert m.rvs_to_transforms[x] is transform
150150

151151
assert res is idata
152-
assert res.log_prior.dims == {"chain": 4, "draw": 25}
152+
assert res.log_prior.sizes == {"chain": 4, "draw": 25}
153+
154+
np.testing.assert_allclose(
155+
res.log_prior["x"].values,
156+
st.norm(0, 1).logpdf(idata.posterior["x"].values),
157+
)
158+
159+
def test_deterministic_log_prior(self):
160+
with Model() as m:
161+
x = Normal("x")
162+
Deterministic("d", 2 * x)
163+
Normal("y", x, observed=[0, 1, 2])
164+
165+
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
166+
res = compute_log_prior(idata)
167+
168+
assert res is idata
169+
assert "x" in res.log_prior
170+
assert "d" not in res.log_prior
171+
assert res.log_prior.sizes == {"chain": 4, "draw": 25}
153172

154173
np.testing.assert_allclose(
155174
res.log_prior["x"].values,

0 commit comments

Comments
 (0)