Skip to content

Commit 7e0f56a

Browse files
ricardoV94 and twiecki
authored and committed
Change rng of partially observed RVs
1 parent bfc9e12 commit 7e0f56a

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pymc/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,16 @@ def make_obs_var(
12761276
clone=False,
12771277
)
12781278
(observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
1279+
# Make a clone of the RV, but change the rng so that observed and missing
1280+
# are not treated as equivalent nodes by aesara. This would happen if the
1281+
# size of the masked and unmasked array happened to coincide
1282+
_, size, _, *inps = observed_rv_var.owner.inputs
1283+
rng = self.model.next_rng()
1284+
observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
1285+
# Add default_update to new rng
1286+
new_rng = observed_rv_var.owner.outputs[0]
1287+
observed_rv_var.update = (rng, new_rng)
1288+
rng.default_update = new_rng
12791289
observed_rv_var.name = f"{name}_observed"
12801290

12811291
observed_rv_var.tag.observations = nonmissing_data

pymc/tests/test_missing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import pytest
1919
import scipy.stats
2020

21+
from aesara.graph import graph_inputs
2122
from numpy import array, ma
2223

24+
from pymc import logpt
2325
from pymc.distributions import Dirichlet, Gamma, Normal, Uniform
2426
from pymc.exceptions import ImputationWarning
2527
from pymc.model import Model
@@ -207,3 +209,26 @@ def test_missing_vector_parameter():
207209
m.logp({"x_missing": np.array([-10, 10, -10, 10])}),
208210
scipy.stats.norm(scale=0.1).logpdf(0) * 6,
209211
)
212+
213+
214+
def test_missing_symmetric():
215+
"""Check that logpt works when a partially observed variable has equal observed and
216+
unobserved dimensions.
217+
218+
This would fail in a previous implementation because the two variables would be
219+
equivalent and one of them would be discarded during MergeOptimization while
220+
building the logpt graph
221+
"""
222+
with Model() as m:
223+
x = Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan]))
224+
225+
x_obs_rv = m["x_observed"]
226+
x_obs_vv = m.rvs_to_values[x_obs_rv]
227+
228+
x_unobs_rv = m["x_missing"]
229+
x_unobs_vv = m.rvs_to_values[x_unobs_rv]
230+
231+
logp = logpt([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv})
232+
logp_inputs = list(graph_inputs([logp]))
233+
assert x_obs_vv in logp_inputs
234+
assert x_unobs_vv in logp_inputs

0 commit comments

Comments
 (0)