Skip to content

Commit afd4740

Browse files
committed
WIP Automatic marginalization finite discrete variables
1 parent 307913d commit afd4740

File tree

3 files changed

+802
-0
lines changed

3 files changed

+802
-0
lines changed

notebooks/marginalized_changepoint_model.ipynb

Lines changed: 409 additions & 0 deletions
Large diffs are not rendered by default.

pymc_experimental/marginal_model.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
from typing import Sequence, Tuple, Union
2+
3+
import aesara.tensor as at
4+
import numpy as np
5+
from aeppl import factorized_joint_logprob
6+
from aeppl.logprob import _logprob
7+
from aesara import clone_replace
8+
from aesara.compile import SharedVariable
9+
from aesara.compile.builders import OpFromGraph
10+
from aesara.graph import Constant, FunctionGraph, ancestors
11+
from aesara.tensor import TensorVariable
12+
from aesara.tensor.elemwise import Elemwise
13+
from pymc import SymbolicRandomVariable
14+
from pymc.aesaraf import inputvars
15+
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
16+
from pymc.model import Model
17+
18+
19+
class MarginalModel(Model):
    """PyMC ``Model`` subclass able to marginalize finite discrete free RVs.

    ``marginalized_rvs`` starts as an empty list on a root model; a nested
    model reuses the very same list object as its parent.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        parent = self.parent
        self.marginalized_rvs = [] if parent is None else parent.marginalized_rvs

    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
        """Replace the RVs that depend on ``rvs_to_marginalize`` by marginal RVs.

        Parameters
        ----------
        rvs_to_marginalize
            A single free RV of the model, or a sequence of them. Only
            Bernoulli, Categorical and DiscreteUniform RVs are accepted.

        Returns
        -------
        dict
            Maps every replaced dependent RV to the RV that took its place.

        Raises
        ------
        ValueError
            If one of the RVs is not a free RV of this model.
        NotImplementedError
            If an RV has an unsupported distribution, or if the model contains
            deterministics or potentials.
        """
        # TODO: this does not need to be a property of a Model
        if not isinstance(rvs_to_marginalize, Sequence):
            rvs_to_marginalize = (rvs_to_marginalize,)

        # Validate everything up front, before any model state is touched
        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                raise ValueError(
                    f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                )
            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
                raise NotImplementedError(
                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
                    f"Supported distribution include {supported_dists}"
                )

        if self.deterministics:
            # TODO: This should be fine if deterministics do not depend on marginalized RVs
            raise NotImplementedError("Models with deterministics cannot be marginalized")

        if self.potentials:
            raise NotImplementedError("Models with potentials cannot be marginalized")

        # The graph is wrapped without cloning, so replacements act on the
        # model's own variables
        fg = FunctionGraph(outputs=self.basic_RVs, clone=False)
        toposort = fg.toposort()

        def _topological_position(rv):
            return toposort.index(rv.owner)

        # Process the RVs in topological order of their owner nodes
        ordered_rvs = sorted(rvs_to_marginalize, key=_topological_position)

        replacements = {}
        for rv_to_marginalize in ordered_rvs:
            old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph(
                fg, rv_to_marginalize, self.rvs_to_values
            )

            # Move the model bookkeeping of every replaced RV over to its successor
            for old_rv, new_rv in zip(old_rvs, new_rvs):
                replacements[old_rv] = new_rv

                # A replaced RV is either free or observed; keep its list position
                rv_list = self.free_RVs if old_rv in self.free_RVs else self.observed_RVs
                position = rv_list.index(old_rv)
                rv_list.pop(position)
                rv_list.insert(position, new_rv)

                value = self.rvs_to_values.pop(old_rv)
                self.rvs_to_values[new_rv] = value
                self.values_to_rvs[value] = new_rv
                self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
                # TODO: Automatic imputation RV does not seem to have total_size mapping
                self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None)

            # This RV can now be safely ignored in the logp graph
            self.free_RVs.remove(rv_to_marginalize)
            value = self.rvs_to_values.pop(rv_to_marginalize)
            self.values_to_rvs.pop(value)
            self.rvs_to_transforms.pop(rv_to_marginalize)
            self.rvs_to_total_sizes.pop(rv_to_marginalize)

        return replacements
86+
87+
88+
def _find_dependent_rvs(dependable_rv, all_rvs):
    """Return the RVs in ``all_rvs`` that have ``dependable_rv`` as an ancestor.

    While walking up from a candidate, every other RV in ``all_rvs`` acts as a
    blocker, so the search does not continue past another RV.
    """

    def _reaches_dependable(candidate):
        stop_at = [rv for rv in all_rvs if rv is not candidate]
        return dependable_rv in ancestors([candidate], blockers=stop_at)

    return [
        candidate
        for candidate in all_rvs
        if candidate is not dependable_rv and _reaches_dependable(candidate)
    ]
98+
99+
100+
def _find_input_rvs(output_rvs, all_rvs):
    """Collect the inputs of the subgraph that ends in ``output_rvs``.

    An ancestor counts as an input when it is one of the other RVs in
    ``all_rvs`` (those also block the traversal), or when it is a root
    variable that is neither a ``Constant`` nor a ``SharedVariable``.
    """
    blockers = [rv for rv in all_rvs if rv not in output_rvs]

    input_rvs = []
    for var in ancestors(output_rvs, blockers=blockers):
        is_blocker = var in blockers
        if is_blocker or (var.owner is None and not isinstance(var, (Constant, SharedVariable))):
            input_rvs.append(var)
    return input_rvs
108+
109+
110+
def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
    """Check that ``rv_to_marginalize`` reaches ``output_rvs`` only through Elemwise ops.

    Returns ``True`` when the marginalized RV is reachable from the outputs
    through Elemwise nodes alone and no other truncated input of that walk
    depends on it; returns ``False`` otherwise.
    """
    # TODO: No need to consider apply nodes outside the subgraph...
    fg = FunctionGraph(outputs=output_rvs, clone=False)

    # Outputs of every non-Elemwise node stop the upward walk
    non_elemwise_outputs = []
    for node in fg.apply_nodes:
        if not isinstance(node.op, Elemwise):
            non_elemwise_outputs.extend(node.outputs)

    candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_outputs
    blockers = [var for var in candidates if var not in output_rvs]

    def _is_free_root(var):
        return var.owner is None and not isinstance(var, (Constant, SharedVariable))

    # TODO: We could actually use these truncated inputs to
    # generate a smaller Marginalized graph...
    truncated_inputs = [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if var in blockers or _is_free_root(var)
    ]

    # Check that we reach the marginalized rv following a pure elemwise graph
    if rv_to_marginalize not in truncated_inputs:
        return False

    # Check that none of the truncated inputs depends on the marginalized_rv
    remaining_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
    # TODO: We don't need to go all the way to the root variables
    upstream = ancestors(remaining_inputs, blockers=[rv_to_marginalize, *other_input_rvs])
    return rv_to_marginalize not in upstream
143+
144+
145+
def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values):
    """Swap the RVs that depend on ``rv_to_marginalize`` for marginal RVs in ``fgraph``.

    Returns a pair ``(dependent_rvs, marginalized_rvs)``: the RVs that were
    replaced and, in the same order, the outputs of the new
    ``FiniteDiscreteMarginalRV`` that replaced them.

    Raises ``NotImplementedError`` when the marginalized RV has more than one
    element and is connected to its dependents through non-Elemwise operations.
    """
    # TODO: This should eventually be integrated in a more general routine that can
    # identify other types of supported marginalization, of which finite discrete
    # RVs is just one

    dependents = _find_dependent_rvs(rv_to_marginalize, rvs_to_values)
    subgraph_inputs = _find_input_rvs(dependents, rvs_to_values)
    remaining_inputs = [rv for rv in subgraph_inputs if rv is not rv_to_marginalize]

    # We don't need to worry about batched graphs if the RV is scalar.
    # TODO: This eval is a bit hackish
    has_many_elements = np.prod(rv_to_marginalize.shape.eval()) > 1
    if has_many_elements and not _is_elemwise_subgraph(
        rv_to_marginalize, remaining_inputs, dependents
    ):
        raise NotImplementedError(
            "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
            "This is currently not supported",
        )

    marginal_op = FiniteDiscreteMarginalRV(
        inputs=[rv_to_marginalize, *remaining_inputs],
        outputs=dependents,
        ndim_supp=None,
    )
    # The marginalized RV's own logp is accounted for inside the marginal logp,
    # so it can be safely ignored elsewhere
    new_rvs = marginal_op(rv_to_marginalize, *remaining_inputs)
    if not isinstance(new_rvs, Sequence):
        # A single output is normalized to a 1-tuple
        new_rvs = (new_rvs,)

    fgraph.replace_all(tuple(zip(dependents, new_rvs)))
    return dependents, new_rvs
174+
175+
176+
class FiniteDiscreteMarginalRV(SymbolicRandomVariable):
    """Symbolic RV wrapping the RVs that depend on a marginalized finite discrete RV.

    Adds no behavior of its own; the class exists as a distinct type to
    dispatch on.
    """
178+
179+
180+
def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
    """Return every value a supported finite discrete RV can take.

    Parameters
    ----------
    rv
        A Bernoulli, Categorical or DiscreteUniform random variable.

    Returns
    -------
    tuple of int
        ``(0, 1)`` for Bernoulli, ``0 .. len(p) - 1`` for Categorical and
        ``lower .. upper`` (both ends included) for DiscreteUniform.

    Raises
    ------
    NotImplementedError
        If the RV's op is none of the supported distributions.
    """
    op = rv.owner.op
    if isinstance(op, Bernoulli):
        return (0, 1)

    if isinstance(op, Categorical):
        # Distribution parameters start at input index 3
        p_param = rv.owner.inputs[3]
        return tuple(range(at.get_vector_length(p_param)))

    if isinstance(op, DiscreteUniform):
        lower, upper = rv.owner.inputs[3:]
        lower_value = int(at.get_scalar_constant_value(lower))
        upper_value = int(at.get_scalar_constant_value(upper))
        # DiscreteUniform includes its upper bound, whereas range() excludes
        # its stop value; without the +1 the last value would be left out of
        # the marginalization
        return tuple(range(lower_value, upper_value + 1))

    raise NotImplementedError(f"Cannot compute domain for op {op}")
197+
198+
199+
@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Log-probability of a ``FiniteDiscreteMarginalRV``.

    Builds the joint logp of the marginalized RV and its dependent RVs, then
    combines that logp evaluated at every value in the marginalized RV's
    domain with ``logsumexp``.

    ``values`` are the value variables of the op's outputs; the first entry of
    ``inputs`` is the marginalized RV and the rest are the op's other inputs.
    ``kwargs`` are forwarded to ``factorized_joint_logprob``.
    """

    # Rebuild the inner graph in terms of the outer inputs
    node = op.make_node(*inputs)
    inner_to_outer = dict(zip(op.inner_inputs, node.inputs))
    dependent_rvs = clone_replace(op.inner_outputs, replace=inner_to_outer)

    marginalized_rv, *remaining_inputs = inputs
    graph_inputs = list(inputvars(remaining_inputs))

    # The marginalized RV gets a placeholder value that is later fed constants
    placeholder_value = marginalized_rv.clone()
    rvs_to_values = {marginalized_rv: placeholder_value}
    rvs_to_values.update(zip(dependent_rvs, values))

    logp_factors = factorized_joint_logprob(
        rv_values=rvs_to_values, warn_missing_rvs=False, **kwargs
    )
    joint_logp = at.sum([at.sum(factor) for factor in logp_factors.values()])

    def _is_symbolic(var):
        return not isinstance(var, (Constant, SharedVariable))

    # OpFromGraph does not accept constant inputs...
    op_value_inputs = [value for value in rvs_to_values.values() if _is_symbolic(value)]
    # TODO: If we inline the logp graph, optimization becomes incredibly painful for
    # large domains... Would be great to find a way to vectorize the graph across
    # the domain values (when possible)
    logp_op = OpFromGraph([*op_value_inputs, *graph_inputs], [joint_logp], inline=False)

    # PyMC does not allow RVs in the logp graph... Even if we are just using the shape
    # TODO: Get better work-around
    marginalized_rv_shape = marginalized_rv.shape.eval()
    symbolic_values = [value for value in values if _is_symbolic(value)]

    # One joint logp per domain value, with the marginalized RV fixed to it
    logps_over_domain = [
        logp_op(np.full(marginalized_rv_shape, domain_value), *symbolic_values, *graph_inputs)
        for domain_value in _get_domain_of_finite_discrete_rv(marginalized_rv)
    ]
    return at.logsumexp(logps_over_domain)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import aesara.tensor as at
2+
import numpy as np
3+
import pandas as pd
4+
import pymc as pm
5+
import pytest
6+
from aeppl.logprob import _logprob
7+
from aesara.graph import ancestors
8+
9+
from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel
10+
11+
12+
def test_marginalized_bernoulli_logp():
    """Test logp of IR TestFiniteMarginalDiscreteRV directly"""
    idx = pm.Bernoulli.dist(0.7, name="idx")
    component_means = at.constant([-1, 1])
    y = pm.Normal.dist(mu=component_means[idx], sigma=1.0, name="y")

    marginal_op = FiniteDiscreteMarginalRV([idx], [y], ndim_supp=None)
    marginal_y = marginal_op(idx)
    marginal_node = marginal_y.owner

    y_vv = y.clone()
    marginal_y_logp = _logprob(marginal_node.op, (y_vv,), *marginal_node.inputs)

    # Marginalizing idx should give the two-component normal mixture logp
    reference_dist = pm.NormalMixture.dist(w=[0.3, 0.7], mu=[-1, 1], sigma=1.0)
    ref_logp = pm.logp(reference_dist, y_vv).sum()

    np.testing.assert_almost_equal(
        marginal_y_logp.eval({y_vv: 2}),
        ref_logp.eval({y_vv: 2}),
    )
31+
32+
33+
def test_marginalize():
    """Marginalizing a Categorical index should match an explicit NormalMixture model."""
    data = [2] * 5

    # Reference model: the mixture is written out explicitly
    with pm.Model() as reference_model:
        sigma = pm.HalfNormal("sigma")
        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
        pm.Normal("z", y, observed=data)

    # Same model with the component index as an explicit discrete RV
    with MarginalModel() as marginal_model:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
        mu_if_not_zero = at.switch(at.eq(idx, 1), 0, 1)
        mu = at.switch(at.eq(idx, 0), -1, mu_if_not_zero)
        y = pm.Normal("y", mu=mu, sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    replacements = marginal_model.marginalize([idx])
    assert len(replacements) == 1

    free_rvs = marginal_model.free_RVs
    assert y not in free_rvs
    assert idx not in free_rvs

    new_y = replacements[y]
    assert new_y in free_rvs
    assert new_y in ancestors([z])

    assert isinstance(new_y.owner.op, FiniteDiscreteMarginalRV)
    # Ignore RNGs
    assert new_y.owner.inputs[:2] == [idx, sigma]

    test_point = reference_model.initial_point()
    # TODO: Test we don't get warnings with missing RVs
    marginal_logp = marginal_model.compile_logp()(test_point)
    reference_logp = reference_model.compile_logp()(test_point)
    np.testing.assert_almost_equal(marginal_logp, reference_logp)
76+
77+
78+
def test_marginalize_nested():
    """Placeholder: this test has not been written yet and always raises."""
    raise NotImplementedError("Must write test")
80+
81+
82+
def test_not_supported_marginalization():
    """Marginalized graphs with non-Elemwise Operations are not supported as they
    would violate the batching logp assumption"""

    mu = at.constant([-1, 1])

    # Allowed, as only elemwise operations connect idx to y
    with MarginalModel() as elemwise_model:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
    assert elemwise_model.marginalize([idx])

    # Allowed, as the index operation does not connect idx to y
    with MarginalModel() as unrelated_index_model:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
    assert unrelated_index_model.marginalize([idx])

    # Not allowed, as the index operation connects idx to y
    with MarginalModel() as index_model:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        pm.Normal("y", mu=mu[idx])
    with pytest.raises(NotImplementedError):
        index_model.marginalize(idx)

    # Not allowed, as the index operation connects idx to y, even though there
    # is also a pure Elemwise connection between the two
    with MarginalModel() as mixed_model:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        pm.Normal("y", mu=mu[idx] + idx)
    with pytest.raises(NotImplementedError):
        mixed_model.marginalize(idx)
119+
120+
121+
def test_change_point_model():
    """Unfinished test: marginalize the switchpoint of a change point model.

    Builds the model, marginalizes ``switchpoint`` and evaluates the compiled
    logp at the initial point, then raises because no assertions exist yet.
    """
    # fmt: off
    disaster_counts = [
        4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
        3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
        2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
        1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
        0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
        3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
    ]
    # fmt: on
    disaster_data = pd.Series(disaster_counts)
    years = np.arange(1851, 1962)

    with MarginalModel() as disaster_model:
        switchpoint = pm.DiscreteUniform(
            "switchpoint", lower=years.min(), upper=years.max(), size=1
        )

        early_rate = pm.Exponential("early_rate", 1.0)
        late_rate = pm.Exponential("late_rate", 1.0)
        rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)

        pm.Poisson("disasters", rate, observed=disaster_data)

    disaster_model.marginalize([switchpoint])

    # Only checks that the marginalized logp compiles and evaluates
    logp_fn = disaster_model.compile_logp()
    logp_fn(disaster_model.initial_point())

    raise NotImplementedError("Test not finished")

0 commit comments

Comments
 (0)