|
18 | 18 | import pytest
|
19 | 19 | import scipy.stats
|
20 | 20 |
|
| 21 | +from aesara.graph import graph_inputs |
21 | 22 | from numpy import array, ma
|
22 | 23 |
|
| 24 | +from pymc import logpt |
23 | 25 | from pymc.distributions import Dirichlet, Gamma, Normal, Uniform
|
24 | 26 | from pymc.exceptions import ImputationWarning
|
25 | 27 | from pymc.model import Model
|
@@ -207,3 +209,26 @@ def test_missing_vector_parameter():
|
207 | 209 | m.logp({"x_missing": np.array([-10, 10, -10, 10])}),
|
208 | 210 | scipy.stats.norm(scale=0.1).logpdf(0) * 6,
|
209 | 211 | )
|
| 212 | + |
| 213 | + |
def test_missing_symmetric():
    """Check that logpt works when partially observed variables have equal observed and
    unobserved dimensions.

    This would fail in a previous implementation because the two variables would be
    equivalent and one of them would be discarded during MergeOptimization while
    building the logpt graph.
    """
    with Model() as m:
        # One observed entry and one NaN entry -> pymc splits x into
        # "x_observed" and "x_missing", each of length 1 (symmetric dims).
        x = Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan]))

    x_obs_rv = m["x_observed"]
    x_obs_vv = m.rvs_to_values[x_obs_rv]

    x_unobs_rv = m["x_missing"]
    x_unobs_vv = m.rvs_to_values[x_unobs_rv]

    # Both value variables must survive as distinct inputs of the joint logp
    # graph; the regression being guarded against merged them into one.
    logp = logpt([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv})
    logp_inputs = list(graph_inputs([logp]))
    assert x_obs_vv in logp_inputs
    assert x_unobs_vv in logp_inputs
0 commit comments