Skip to content

Commit 31916ab

Browse files
Add test for univariate and multivariate marginal mixture
Fix issue with `ndim_supp` Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
1 parent 8dff969 commit 31916ab

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
580580
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
581581

582582
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
583-
if max(ndim_supp) > 0:
583+
if len(ndim_supp) != 1:
584584
raise NotImplementedError(
585-
"Marginalization of withe dependent Multivariate RVs not implemented"
585+
"Marginalization with dependent variables of different support not implemented"
586586
)
587+
[ndim_supp] = ndim_supp
588+
if ndim_supp > 0:
589+
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")
587590

588591
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
589592
dependent_rvs_input_rvs = [
@@ -621,7 +624,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
621624
marginalization_op = FiniteDiscreteMarginalRV(
622625
inputs=list(replace_inputs.values()),
623626
outputs=cloned_outputs,
624-
ndim_supp=0,
627+
ndim_supp=ndim_supp,
625628
)
626629
marginalized_rvs = marginalization_op(*replace_inputs.keys())
627630
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,47 @@ def test_is_conditional_dependent_static_shape():
598598
x2 = pt.matrix("x2", shape=(9, 5))
599599
y2 = pt.random.normal(size=pt.shape(x2))
600600
assert not is_conditional_dependent(y2, x2, [x2, y2])
601+
602+
603+
@pytest.mark.parametrize("univariate", (True, False))
604+
def test_vector_univariate_mixture(univariate):
605+
606+
with MarginalModel() as m:
607+
idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
608+
609+
def dist(idx, size):
610+
return pm.math.switch(
611+
pm.math.eq(idx, 0),
612+
pm.Normal.dist([-10, -10], 1),
613+
pm.Normal.dist([10, 10], 1),
614+
)
615+
616+
pm.CustomDist("norm", idx, dist=dist)
617+
618+
m.marginalize(idx)
619+
logp_fn = m.compile_logp()
620+
621+
if univariate:
622+
with pm.Model() as ref_m:
623+
pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
624+
else:
625+
with pm.Model() as ref_m:
626+
pm.Mixture(
627+
"norm",
628+
w=[0.5, 0.5],
629+
comp_dists=[
630+
pm.MvNormal.dist([-10, -10], np.eye(2)),
631+
pm.MvNormal.dist([10, 10], np.eye(2)),
632+
],
633+
shape=(2,),
634+
)
635+
ref_logp_fn = ref_m.compile_logp()
636+
637+
for test_value in (
638+
[-10, -10],
639+
[10, 10],
640+
[-10, 10],
641+
[-10, 10],
642+
):
643+
pt = {"norm": test_value}
644+
np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

0 commit comments

Comments
 (0)