Skip to content

Commit a21e1f8

Browse files
committed
WIP Automatic marginalization finite discrete variables
1 parent 307913d commit a21e1f8

File tree

2 files changed

+398
-0
lines changed

2 files changed

+398
-0
lines changed

pymc_experimental/marginal_model.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
from typing import Sequence, Tuple, Union
2+
3+
import aesara.tensor as at
4+
import numpy as np
5+
from aeppl import factorized_joint_logprob
6+
from aeppl.logprob import _logprob
7+
from aesara import clone_replace
8+
from aesara.compile import SharedVariable
9+
from aesara.compile.builders import OpFromGraph
10+
from aesara.graph import Constant, FunctionGraph, ancestors
11+
from aesara.tensor import TensorVariable
12+
from aesara.tensor.elemwise import Elemwise
13+
from pymc import SymbolicRandomVariable
14+
from pymc.aesaraf import inputvars
15+
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
16+
from pymc.model import Model
17+
18+
19+
def _replace_marginalized_subgraph(fgraph, rv_to_marginalize):
    """Probe whether a graph can be expressed in terms of ``rv_to_marginalize``.

    Work-in-progress helper: it only constructs a ``FunctionGraph`` from the
    outputs of ``fgraph`` and discards it. Nothing is replaced and nothing
    is returned.
    """
    # Check if it's even valid.
    # NOTE(review): `rv_to_marginalize` is forwarded unchanged as `inputs`;
    # presumably a sequence of variables is expected there -- confirm.
    validity_probe = FunctionGraph(
        inputs=rv_to_marginalize,
        outputs=fgraph.outputs,
        clone=False,
    )
22+
23+
24+
class MarginalModel(Model):
    """Model that can marginalize finite discrete free RVs out of its graph.

    Attributes
    ----------
    marginalized_rvs : list
        Shared with the parent model when there is one, otherwise a new
        empty list.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A nested model reuses the very same list object as its parent
        parent = self.parent
        self.marginalized_rvs = [] if parent is None else parent.marginalized_rvs

    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
        """Marginalize one or more finite discrete free RVs out of the model.

        Each RV that depends on a marginalized RV is swapped for an output of
        a ``FiniteDiscreteMarginalRV`` and the model's bookkeeping (RV lists,
        value, transform and total-size mappings) is updated in place.

        Parameters
        ----------
        rvs_to_marginalize
            A single RV or a sequence of RVs. Each must be a free RV of this
            model with a Bernoulli, Categorical or DiscreteUniform op.

        Returns
        -------
        dict
            Maps every replaced dependent RV to its replacement.

        Raises
        ------
        ValueError
            If an RV is not among the model's free RVs.
        NotImplementedError
            If an RV has an unsupported distribution, or the model has
            deterministics or potentials.
        """
        # TODO: this does not need to be a property of a Model
        if not isinstance(rvs_to_marginalize, Sequence):
            rvs_to_marginalize = (rvs_to_marginalize,)

        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for candidate in rvs_to_marginalize:
            if candidate not in self.free_RVs:
                raise ValueError(f"Marginalized RV {candidate} is not a free RV in the model")
            candidate_op = candidate.owner.op
            if not isinstance(candidate_op, supported_dists):
                raise NotImplementedError(
                    f"RV with distribution {candidate_op} cannot be marginalized. "
                    f"Supported distribution include {supported_dists}"
                )

        # TODO: This should be fine if deterministics do not depend on marginalized RVs
        if self.deterministics:
            raise NotImplementedError("Models with deterministics cannot be marginalized")

        if self.potentials:
            raise NotImplementedError("Models with potentials cannot be marginalized")

        # Replaced with subgraph that need to be marginalized for each RV
        fg = FunctionGraph(outputs=self.basic_RVs, clone=False)
        node_order = fg.toposort()

        def _topological_position(rv):
            return node_order.index(rv.owner)

        replacements = {}
        # RVs are processed in the topological order of the nodes that create them
        for rv_to_marginalize in sorted(rvs_to_marginalize, key=_topological_position):
            old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph(
                fg, rv_to_marginalize, self.rvs_to_values
            )

            # Update old mappings
            for old_rv, new_rv in zip(old_rvs, new_rvs):
                replacements[old_rv] = new_rv

                # Swap the RV in whichever list holds it, keeping its position
                rv_list = self.free_RVs if old_rv in self.free_RVs else self.observed_RVs
                position = rv_list.index(old_rv)
                rv_list.pop(position)
                rv_list.insert(position, new_rv)

                value = self.rvs_to_values.pop(old_rv)
                self.rvs_to_values[new_rv] = value
                self.values_to_rvs[value] = new_rv
                self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
                # TODO: Automatic imputation RV does not seem to have total_size mapping
                self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None)

            # This RV can now be safely ignored in the logp graph
            self.free_RVs.remove(rv_to_marginalize)
            self.values_to_rvs.pop(self.rvs_to_values.pop(rv_to_marginalize))
            self.rvs_to_transforms.pop(rv_to_marginalize)
            self.rvs_to_total_sizes.pop(rv_to_marginalize)

        return replacements
91+
92+
93+
def _find_dependent_rvs(dependable_rv, all_rvs):
    """Return the RVs in ``all_rvs`` that depend on ``dependable_rv``.

    For every candidate the ancestor walk is blocked at all the other RVs in
    ``all_rvs``; the candidate is kept when ``dependable_rv`` shows up in
    that walk. The result preserves the iteration order of ``all_rvs``.
    """

    def _reaches_dependable(rv):
        blockers = [other_rv for other_rv in all_rvs if other_rv is not rv]
        return dependable_rv in ancestors([rv], blockers=blockers)

    return [rv for rv in all_rvs if rv is not dependable_rv and _reaches_dependable(rv)]
103+
104+
105+
def _find_input_rvs(output_rvs, all_rvs):
    """Collect the inputs of the subgraph that produces ``output_rvs``.

    The ancestor walk is blocked at every RV in ``all_rvs`` that is not one
    of the outputs. A visited variable is reported when it is one of those
    blockers, or when it has no owner and is neither a ``Constant`` nor a
    ``SharedVariable``.
    """
    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]

    input_rvs = []
    for var in ancestors(output_rvs, blockers=blockers):
        if var in blockers:
            input_rvs.append(var)
        elif var.owner is None and not isinstance(var, (Constant, SharedVariable)):
            input_rvs.append(var)
    return input_rvs
113+
114+
115+
def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
    """Tell whether ``rv_to_marginalize`` reaches ``output_rvs`` only via Elemwise ops.

    Returns ``False`` when the marginalized RV cannot be reached from the
    outputs through Elemwise nodes alone, or when any other truncated input
    of that walk itself depends on the marginalized RV. Returns ``True``
    otherwise.
    """
    # TODO: No need to consider apply nodes outside the subgraph...
    fg = FunctionGraph(outputs=output_rvs, clone=False)

    # Outputs of every non-Elemwise node stop the ancestor walk below
    non_elemwise_blockers = []
    for node in fg.apply_nodes:
        if not isinstance(node.op, Elemwise):
            non_elemwise_blockers.extend(node.outputs)

    blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
    blockers = [var for var in blocker_candidates if var not in output_rvs]

    # TODO: We could actually use these truncated inputs to
    # generate a smaller Marginalized graph...
    truncated_inputs = []
    for var in ancestors(output_rvs, blockers=blockers):
        is_root_input = var.owner is None and not isinstance(var, (Constant, SharedVariable))
        if var in blockers or is_root_input:
            truncated_inputs.append(var)

    # Check that we reach the marginalized rv following a pure elemwise graph
    if rv_to_marginalize not in truncated_inputs:
        return False

    # Check that none of the truncated inputs depends on the marginalized_rv
    other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
    # TODO: We don't need to go all the way to the root variables
    upstream = ancestors(other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs])
    return rv_to_marginalize not in upstream
148+
149+
150+
class FiniteDiscreteMarginalRV(SymbolicRandomVariable):
    """Symbolic RV whose outputs are the RVs that depend on a marginalized RV.

    This module builds it with the marginalized finite discrete RV as its
    first input and registers a dedicated ``_logprob`` implementation for it.
    """
152+
153+
154+
def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values):
    """Replace the dependents of ``rv_to_marginalize`` with marginal RVs in ``fgraph``.

    Parameters
    ----------
    fgraph
        Function graph in which the replacement is performed.
    rv_to_marginalize
        The finite discrete RV being marginalized.
    rvs_to_values
        Iterated to obtain all the RVs considered when looking for
        dependents and inputs.

    Returns
    -------
    tuple
        ``(dependent_rvs, marginalized_rvs)``: the replaced RVs and their
        replacements, in matching order.

    Raises
    ------
    NotImplementedError
        If the RV has more than one element and is connected to its
        dependents through non-Elemwise operations.
    """
    # TODO: This should eventually be integrated in a more general routine that can
    #  identify other types of supported marginalization, of which finite discrete
    #  RVs is just one

    dependent_rvs = _find_dependent_rvs(rv_to_marginalize, rvs_to_values)
    input_rvs = _find_input_rvs(dependent_rvs, rvs_to_values)
    other_input_rvs = [rv for rv in input_rvs if rv is not rv_to_marginalize]

    # We don't need to worry about batched graphs if the RV is scalar.
    # TODO: This eval is a bit hackish
    has_many_elements = np.prod(rv_to_marginalize.shape.eval()) > 1
    if has_many_elements and not _is_elemwise_subgraph(
        rv_to_marginalize, other_input_rvs, dependent_rvs
    ):
        raise NotImplementedError(
            "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
            "This is currently not supported",
        )

    marginalization_op = FiniteDiscreteMarginalRV(
        inputs=[rv_to_marginalize, *other_input_rvs],
        outputs=dependent_rvs,
        ndim_supp=None,
    )
    # Marginalized_RV logp is accounted by in the logp, so it can be safely ignored
    # rv_to_marginalize = ignore_logprob(rv_to_marginalize)
    marginalized_rvs = marginalization_op(rv_to_marginalize, *other_input_rvs)
    # Normalize a lone output to a one-element tuple
    if not isinstance(marginalized_rvs, Sequence):
        marginalized_rvs = (marginalized_rvs,)

    rv_pairs = tuple(zip(dependent_rvs, marginalized_rvs))
    fgraph.replace_all(rv_pairs)
    return dependent_rvs, marginalized_rvs
183+
184+
185+
def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
    """Enumerate the values a finite discrete RV can take.

    Parameters
    ----------
    rv
        Variable whose owner op is a Bernoulli, Categorical or
        DiscreteUniform.

    Returns
    -------
    tuple of int
        ``(0, 1)`` for Bernoulli; ``0 .. len(p) - 1`` for Categorical;
        ``lower .. upper`` with both bounds included for DiscreteUniform.

    Raises
    ------
    NotImplementedError
        If the op is none of the supported distributions.
    """
    op = rv.owner.op
    if isinstance(op, Bernoulli):
        return (0, 1)

    if isinstance(op, Categorical):
        p_param = rv.owner.inputs[3]
        return tuple(range(at.get_vector_length(p_param)))

    if isinstance(op, DiscreteUniform):
        lower, upper = rv.owner.inputs[3:]
        lower_value = int(at.get_scalar_constant_value(lower))
        upper_value = int(at.get_scalar_constant_value(upper))
        # The support of DiscreteUniform includes `upper`, whereas `range`
        # excludes its stop value; without the +1 the last point would be
        # silently left out of the marginalization.
        return tuple(range(lower_value, upper_value + 1))

    raise NotImplementedError(f"Cannot compute domain for op {op}")
202+
203+
204+
@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Log-probability of a ``FiniteDiscreteMarginalRV``.

    The joint logp of the marginalized RV (first entry of ``inputs``) and
    the op's outputs is wrapped in an ``OpFromGraph``. That op is evaluated
    once per value in the marginalized RV's domain and the results are
    combined with ``logsumexp``.

    Parameters
    ----------
    op
        The ``FiniteDiscreteMarginalRV`` being evaluated.
    values
        Value variables paired, in order, with the op's outputs.
    *inputs
        Outer inputs of the op; the first one is the marginalized RV.
    **kwargs
        Forwarded to ``factorized_joint_logprob``.
    """
    # Express the op's inner outputs in terms of the outer inputs
    marginalized_rvs_node = op.make_node(*inputs)
    inner_to_outer = dict(zip(op.inner_inputs, marginalized_rvs_node.inputs))
    marginalized_rvs = clone_replace(op.inner_outputs, replace=inner_to_outer)

    marginalized_rv, *other_inputs = inputs
    other_inputs = list(inputvars(other_inputs))

    # A clone of the marginalized RV acts as its placeholder value variable
    dummy_marginalized_value = marginalized_rv.clone()
    rvs_to_values = {marginalized_rv: dummy_marginalized_value}
    rvs_to_values.update(zip(marginalized_rvs, values))

    logp_factors = factorized_joint_logprob(
        rv_values=rvs_to_values, warn_missing_rvs=False, **kwargs
    )
    joint_logp = at.sum([at.sum(factor) for factor in logp_factors.values()])

    # OpFromGraph does not accept constant inputs...
    graph_values = []
    for value in rvs_to_values.values():
        if not isinstance(value, (Constant, SharedVariable)):
            graph_values.append(value)

    # TODO: If we inline the logp graph, optimization becomes incredibly painful for
    #  large domains... Would be great to find a way to vectorize the graph across
    #  the domain values (when possible)
    logp_op = OpFromGraph([*graph_values, *other_inputs], [joint_logp], inline=False)

    # PyMC does not allow RVs in the logp graph... Even if we are just using the shape
    # TODO: Get better work-around
    marginalized_rv_shape = marginalized_rv.shape.eval()
    values = [value for value in values if not isinstance(value, (Constant, SharedVariable))]

    # One joint logp term per possible value of the marginalized RV
    logp_terms = []
    for marginalized_rv_const in _get_domain_of_finite_discrete_rv(marginalized_rv):
        constant_value = np.full(marginalized_rv_shape, marginalized_rv_const)
        logp_terms.append(logp_op(constant_value, *values, *other_inputs))
    return at.logsumexp(logp_terms)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import aesara.tensor as at
2+
import numpy as np
3+
import pandas as pd
4+
import pymc as pm
5+
import pytest
6+
from aeppl.logprob import _logprob
7+
from aesara.graph import ancestors
8+
9+
from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel
10+
11+
12+
def test_marginalized_bernoulli_logp():
    """Compare the logp of a ``FiniteDiscreteMarginalRV`` with a ``NormalMixture``."""
    idx = pm.Bernoulli.dist(0.7, name="idx")
    mu = at.constant([-1, 1])[idx]
    y = pm.Normal.dist(mu=mu, sigma=1.0, name="y")
    marginal_op = FiniteDiscreteMarginalRV([idx], [y], ndim_supp=None)
    marginal_y = marginal_op(idx)

    # Value variable at which both logps are evaluated
    y_vv = y.clone()
    marginal_node = marginal_y.owner
    marginal_y_logp = _logprob(marginal_node.op, (y_vv,), *marginal_node.inputs)

    reference_mixture = pm.NormalMixture.dist(w=[0.3, 0.7], mu=[-1, 1], sigma=1.0)
    ref_logp = pm.logp(reference_mixture, y_vv).sum()

    np.testing.assert_almost_equal(
        marginal_y_logp.eval({y_vv: 2}),
        ref_logp.eval({y_vv: 2}),
    )
31+
32+
33+
def test_marginalize():
    """Marginalizing a Categorical index should match the logp of a NormalMixture model."""
    data = [2] * 5

    # Reference model with the mixture written out explicitly
    with pm.Model() as m_ref:
        sigma = pm.HalfNormal("sigma")
        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
        # idx == 0 -> -1, idx == 1 -> 0, anything else -> 1
        mu = at.switch(at.eq(idx, 0), -1, at.switch(at.eq(idx, 1), 0, 1))
        y = pm.Normal("y", mu=mu, sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    replacements = m.marginalize([idx])
    assert len(replacements) == 1

    for removed_rv in (y, idx):
        assert removed_rv not in m.free_RVs

    new_y = replacements[y]
    assert new_y in m.free_RVs
    assert new_y in ancestors([z])

    assert isinstance(new_y.owner.op, FiniteDiscreteMarginalRV)
    # Ignore RNGs
    assert new_y.owner.inputs[:2] == [idx, sigma]

    test_point = m_ref.initial_point()
    # TODO: Test we don't get warnings with missing RVs
    marginal_logp = m.compile_logp()(test_point)
    reference_logp = m_ref.compile_logp()(test_point)
    np.testing.assert_almost_equal(marginal_logp, reference_logp)
76+
77+
78+
def test_marginalize_nested():
    # Placeholder: nothing is exercised yet, so this test always fails
    # with NotImplementedError until it is written.
    raise NotImplementedError("Must write test")
80+
81+
82+
def test_not_supported_marginalization():
    """Marginalized graphs with non-Elemwise Operations are not supported as they
    would violate the batching logp assumption"""
    mu = at.constant([-1, 1])

    # Allowed, as only elemwise operations connect idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
        assert m.marginalize([idx])

    # Allowed, as the index operation does not connect idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
        assert m.marginalize([idx])

    # Not allowed, as the index operation connects idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=mu[idx])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

    # Not allowed, as the index operation connects idx to y, even though
    # there is also a pure Elemwise connection between the two
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=mu[idx] + idx)
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)
119+
120+
121+
def test_change_point_model():
    """Marginalize the switchpoint of a change point model (unfinished test).

    Builds the model, marginalizes the switchpoint and compiles/evaluates
    the logp at the initial point, then fails on purpose because no
    assertions have been written yet.
    """
    # fmt: off
    disaster_data = pd.Series(
        [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
         3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
         2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
         1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
         0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
         3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
         0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
    )
    # fmt: on
    years = np.arange(1851, 1962)

    with MarginalModel() as disaster_model:
        switchpoint = pm.DiscreteUniform(
            "switchpoint",
            lower=years.min(),
            upper=years.max(),
            size=1,
        )

        early_rate = pm.Exponential("early_rate", 1.0)
        late_rate = pm.Exponential("late_rate", 1.0)
        # Years up to and including the switchpoint use the early rate
        rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)

        disasters = pm.Poisson("disasters", rate, observed=disaster_data)

    disaster_model.marginalize([switchpoint])
    initial_point = disaster_model.initial_point()
    disaster_model.compile_logp()(initial_point)

    raise NotImplementedError("Test not finished")

0 commit comments

Comments
 (0)