
Commit 4f10d49

WIP: Automatic marginalization of finite discrete variables
1 parent 307913d commit 4f10d49

File tree

3 files changed: +1412 −0 lines changed

notebooks/marginalized_changepoint_model.ipynb

Lines changed: 839 additions & 0 deletions
Large diffs are not rendered by default.

pymc_experimental/marginal_model.py

Lines changed: 361 additions & 0 deletions
@@ -0,0 +1,361 @@
from typing import Dict, Sequence, Tuple, Union

import aesara.tensor as at
import numpy as np
from aeppl import factorized_joint_logprob
from aeppl.abstract import _get_measurable_outputs
from aeppl.logprob import _logprob
from aesara import clone_replace
from aesara.compile import SharedVariable
from aesara.compile.builders import OpFromGraph
from aesara.graph import Constant, FunctionGraph, ancestors
from aesara.tensor import TensorVariable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import (
    RandomGeneratorSharedVariable,
    RandomStateSharedVariable,
)
from pymc import SymbolicRandomVariable
from pymc.aesaraf import constant_fold, inputvars
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.distribution import _moment, moment
from pymc.model import Model


class MarginalModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.parent is not None:
            raise NotImplementedError("MarginalModel cannot be used inside another Model")
        else:
            self.marginalized_rvs_to_dependent_rvs = {}

    def logp(self, vars=None, **kwargs):
        if not kwargs.get("sum", True):
            # Check if dependent RVs were requested
            if vars is not None and not isinstance(vars, Sequence):
                vars = (vars,)
            if vars is None or (
                {v for vs in self.marginalized_rvs_to_dependent_rvs.values() for v in vs}
                & {self.values_to_rvs.get(var, var) for var in vars}
            ):
                raise ValueError(
                    "Cannot request elemwise logp (sum=False) for variables that depend on a marginalized RV"
                )
        return super().logp(vars, **kwargs)

    def point_logps(self, *args, **kwargs):
        # TODO: Fix this
        return {}

    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
        # TODO: this does not need to be a method of a Model
        if not isinstance(rvs_to_marginalize, Sequence):
            rvs_to_marginalize = (rvs_to_marginalize,)

        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                raise ValueError(
                    f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                )
            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
                raise NotImplementedError(
                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
                    f"Supported distributions include {supported_dists}"
                )

        if self.deterministics:
            # TODO: This should be fine if deterministics do not depend on marginalized RVs
            raise NotImplementedError("Models with deterministics cannot be marginalized")

        if self.potentials:
            raise NotImplementedError("Models with potentials cannot be marginalized")

        # Replace each RV to be marginalized with the subgraph that marginalizes it
        fg = FunctionGraph(outputs=self.basic_RVs, clone=False)
        toposort = fg.toposort()
        replacements = {}
        new_marginalized_rv = None
        new_dependent_rvs = []
        for rv_to_marginalize in sorted(
            rvs_to_marginalize, key=lambda rv: toposort.index(rv.owner)
        ):
            old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph(
                fg, rv_to_marginalize, self.rvs_to_values
            )
            # Update old mappings
            for old_rv, new_rv in zip(old_rvs, new_rvs):
                replacements[old_rv] = new_rv

                value = self.rvs_to_values.pop(old_rv)
                self.named_vars.pop(old_rv.name)
                new_rv.name = old_rv.name

                if old_rv is rv_to_marginalize:
                    self.free_RVs.remove(old_rv)
                    self.values_to_rvs.pop(value)
                    self.rvs_to_transforms.pop(old_rv)
                    self.rvs_to_total_sizes.pop(old_rv)
                    new_marginalized_rv = new_rv
                    continue

                new_dependent_rvs.append(new_rv)
                if old_rv in self.free_RVs:
                    index = self.free_RVs.index(old_rv)
                    self.free_RVs.pop(index)
                    self.free_RVs.insert(index, new_rv)
                    self._initial_values[new_rv] = self._initial_values.pop(old_rv)
                else:
                    index = self.observed_RVs.index(old_rv)
                    self.observed_RVs.pop(index)
                    self.observed_RVs.insert(index, new_rv)
                self.rvs_to_values[new_rv] = value
                self.named_vars[new_rv.name] = new_rv
                self.values_to_rvs[value] = new_rv
                self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
                # TODO: Automatic imputation RVs do not seem to have a total_size mapping
                self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None)

        self.marginalized_rvs_to_dependent_rvs[new_marginalized_rv] = new_dependent_rvs
        return replacements

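For context, here is a minimal usage sketch of the API above. It is illustrative only: the model, data, and variable names are invented here, and since this is a WIP commit, end-to-end behavior is not guaranteed.

    import numpy as np
    import pymc as pm

    from pymc_experimental.marginal_model import MarginalModel

    with MarginalModel() as m:
        # Two-component mixture with an explicit Bernoulli indicator
        idx = pm.Bernoulli("idx", p=0.7)
        mu = pm.math.switch(pm.math.eq(idx, 1), 1.0, -1.0)  # Elemwise bridge to y
        y = pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.randn(100))

    # Replaces idx and y with outputs of a FiniteDiscreteMarginalRV; logp(y)
    # becomes a two-term logsumexp over idx in {0, 1}
    m.marginalize([idx])

    # Elemwise logps (sum=False) of y are no longer defined after marginalization:
    # m.logp(vars=[y], sum=False)  # raises ValueError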

def _find_dependent_rvs(dependable_rv, all_rvs):
    # Find RVs that depend on dependable_rv
    dependent_rvs = []
    for rv in all_rvs:
        if rv is dependable_rv:
            continue
        blockers = [other_rv for other_rv in all_rvs if other_rv is not rv]
        if dependable_rv in ancestors([rv], blockers=blockers):
            dependent_rvs.append(rv)
    return dependent_rvs


def _find_input_rvs(output_rvs, all_rvs):
    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
    return [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if var in blockers
        or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
    ]

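Both helpers lean on the blocker semantics of `aesara.graph.ancestors`: the traversal yields any blocker it reaches but does not walk past it. A tiny sketch of that behavior, with illustrative variables:

    import aesara.tensor as at
    from aesara.graph import ancestors

    x = at.scalar("x")
    y = x + 1
    z = y * 2

    assert y in set(ancestors([z], blockers=[y]))      # reached blockers are yielded
    assert x not in set(ancestors([z], blockers=[y]))  # but not traversed past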

def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
    # TODO: No need to consider apply nodes outside the subgraph...
    fg = FunctionGraph(outputs=output_rvs, clone=False)

    non_elemwise_blockers = [
        o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
    ]
    blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
    blockers = [var for var in blocker_candidates if var not in output_rvs]

    # TODO: We could actually use these truncated inputs to
    # generate a smaller marginalized graph...
    truncated_inputs = [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if (
            var in blockers
            or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
        )
    ]

    # Check that we reach the marginalized RV following a pure Elemwise graph
    if rv_to_marginalize not in truncated_inputs:
        return False

    # Check that none of the truncated inputs depends on the marginalized RV
    other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
    # TODO: We don't need to go all the way to the root variables
    if rv_to_marginalize in ancestors(
        other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs]
    ):
        return False
    return True

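To illustrate what `_is_elemwise_subgraph` is meant to accept and reject for a batched marginalized RV (a hypothetical graph; `at.switch` lowers to an Elemwise node, while `at.cumsum` does not):

    import aesara.tensor as at

    idx = at.ivector("idx")  # stand-in for a batched marginalized RV

    mu_ok = at.switch(at.eq(idx, 1), 1.0, -1.0)  # pure Elemwise bridge: accepted
    mu_bad = at.cumsum(idx)  # CumOp is not Elemwise: rejected for batched RVs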

SUPPORTED_RNG_TYPES = (RandomStateSharedVariable, RandomGeneratorSharedVariable)


class FiniteDiscreteMarginalRV(SymbolicRandomVariable):
    def __init__(self, *args, n_updates: int, **kwargs):
        self.n_updates = n_updates
        super().__init__(*args, **kwargs)

    def update(self, node):
        n_updates = node.op.n_updates
        shared_rng_inputs = node.inputs[:n_updates]
        update_outputs = node.outputs[:n_updates]
        assert len(update_outputs) == len(shared_rng_inputs)
        # We made sure to pass RNG inputs and output updates in the same order
        return {inp: out for inp, out in zip(shared_rng_inputs, update_outputs)}


def _collect_updates(rvs: Sequence[TensorVariable]) -> Dict[TensorVariable, TensorVariable]:
    rng_updates = {}
    for rv in rvs:
        if isinstance(rv.owner.op, RandomVariable):
            rng = rv.owner.inputs[0]
            assert not hasattr(rng, "default_update")
            rng_updates[rng] = rv.owner.outputs[0]
        elif isinstance(rv.owner.op, SymbolicRandomVariable):
            rng_updates.update(rv.owner.op.update(rv.owner))
        else:
            raise TypeError(f"Unknown RV type: {rv.owner.op}")
    assert all(isinstance(rng, SUPPORTED_RNG_TYPES) for rng in rng_updates.keys())
    return rng_updates

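`_collect_updates` relies on the node conventions of Aesara `RandomVariable` Ops:

    rv.owner.inputs[0]   -> the shared RNG variable feeding the draw
    rv.owner.outputs[0]  -> the next RNG state (the update for that shared RNG)
    rv.owner.outputs[1]  -> the drawn value, i.e. rv itself

SymbolicRandomVariables instead report their updates through an `update` method, as defined on FiniteDiscreteMarginalRV above.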
def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values):
    # TODO: This should eventually be integrated in a more general routine that can
    # identify other types of supported marginalization, of which finite discrete
    # RVs is just one

    dependent_rvs = _find_dependent_rvs(rv_to_marginalize, rvs_to_values)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = _find_input_rvs([rv_to_marginalize], rvs_to_values)
    dependent_rvs_input_rvs = [
        rv for rv in _find_input_rvs(dependent_rvs, rvs_to_values) if rv is not rv_to_marginalize
    ]

    # If the marginalized RV has batched dimensions, check that the graph between it
    # and its dependent RVs is composed strictly of Elemwise operations. This implies (?)
    # that the batched dimensions are independent, so a logp graph can ultimately be
    # generated whose size is proportional to the support domain and not to the
    # batched dimensions. We don't need to worry about batched graphs if the RV is scalar.
    # TODO: This eval is a bit hackish
    if np.prod(rv_to_marginalize.shape.eval()) > 1:
        if not _is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs):
            raise NotImplementedError(
                "The subgraph between a marginalized RV and its dependents includes non-Elemwise operations. "
                "This is currently not supported",
            )

    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

    # Collect update expressions of the inner RVs.
    # Note: This could be avoided if we inlined the MarginalOp graph before collecting
    # the updates in `pymc.aesaraf.compile_pymc`
    updates_rvs_to_marginalize = _collect_updates(rvs_to_marginalize)
    n_updates = len(updates_rvs_to_marginalize)
    assert n_updates

    outputs = list(updates_rvs_to_marginalize.values()) + rvs_to_marginalize
    # Clone-replace inner RV rng inputs so that we can be sure of the update order
    replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
    # Clone-replace outer RV inputs, so that their shared RNGs don't make it into
    # the inner graph of the marginalized RVs
    replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
    cloned_outputs = clone_replace(outputs, replace=replace_inputs)

    marginalization_op = FiniteDiscreteMarginalRV(
        inputs=list(replace_inputs.values()),
        outputs=cloned_outputs,
        ndim_supp=-1,  # This will certainly break stuff :D
        n_updates=n_updates,
    )
    marginalized_rvs = marginalization_op(*replace_inputs.keys())[n_updates:]
    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs

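The Op built by this routine follows a fixed output layout, which `update`, `_get_measurable_outputs`, and the moment/logp dispatchers below all slice against:

    outputs[:n_updates]      -> RNG update outputs (same order as the shared RNG inputs)
    outputs[n_updates]       -> the marginalized RV
    outputs[n_updates + 1:]  -> the dependent RVs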
@_get_measurable_outputs.register(FiniteDiscreteMarginalRV)
def _get_measurable_outputs_finite_discrete_marginal_rv(op, node):
    # The marginalized RV (first non-update output) is not measurable, nor are the updates
    return node.outputs[op.n_updates + 1 :]


@_moment.register(FiniteDiscreteMarginalRV)
def moment_finite_discrete_marginal_rv(op, rv, *rv_inputs):
    # Recreate the inner RVs and retrieve the moment of the requested one
    node = rv.owner
    marginalized_rv, *dependent_rvs = clone_replace(
        op.inner_outputs[op.n_updates :],
        replace={u: v for u, v in zip(op.inner_inputs, rv_inputs)},
    )
    rv_idx = node.outputs[op.n_updates + 1 :].index(rv)
    rv = dependent_rvs[rv_idx]

    moment_marginalized_rv = moment(marginalized_rv)
    (rv,) = clone_replace([rv], replace={marginalized_rv: moment_marginalized_rv})
    return moment(rv)


def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
    op = rv.owner.op
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        p_param = rv.owner.inputs[3]
        return tuple(range(at.get_vector_length(p_param)))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(rv.owner.inputs[3:])
        return tuple(range(lower, upper + 1))

    raise NotImplementedError(f"Cannot compute domain for op {op}")

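Concretely, the domains this helper produces (with illustrative parameter values):

    Bernoulli(p)                    -> (0, 1)
    Categorical(p=[0.2, 0.3, 0.5])  -> (0, 1, 2)
    DiscreteUniform(-1, 2)          -> (-1, 0, 1, 2)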

@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *dependent_rvs = clone_replace(
        op.inner_outputs[op.n_updates :],
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Some inputs are not root inputs (such as transformed projections of value variables)
    # or cannot be used as inputs to an OpFromGraph (shared variables and constants)
    inputs = list(inputvars(inputs))

    rvs_to_values = {}
    dummy_marginalized_value = marginalized_rv.clone()
    rvs_to_values[marginalized_rv] = dummy_marginalized_value
    rvs_to_values.update(zip(dependent_rvs, values))
    _logp = at.sum(
        [
            at.sum(factor)
            for factor in factorized_joint_logprob(rv_values=rvs_to_values, **kwargs).values()
        ]
    )
    # OpFromGraph does not accept constant inputs...
    _values = [
        value
        for value in rvs_to_values.values()
        if not isinstance(value, (Constant, SharedVariable))
    ]
    # TODO: If we inline the logp graph, optimization becomes incredibly painful for
    # large domains... It would be great to find a way to vectorize the graph across
    # the domain values (when possible)
    logp_op = OpFromGraph([*_values, *inputs], [_logp], inline=False)

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape.
    # TODO: Find a better work-around than .eval(). It probably makes sense to do a
    # constant-fold pass on the final logp graph, so that individual logp functions
    # don't have to worry about it
    marginalized_rv_shape = marginalized_rv.shape.eval()
    non_const_values = [
        value for value in values if not isinstance(value, (Constant, SharedVariable))
    ]
    logp = at.logsumexp(
        [
            logp_op(
                np.full(marginalized_rv_shape, marginalized_rv_const), *non_const_values, *inputs
            )
            for marginalized_rv_const in _get_domain_of_finite_discrete_rv(marginalized_rv)
        ]
    )
    # In the case of multiple dependent values, the whole logp is assigned to the
    # first value. This is quite hackish, but Aeppl errors out if some value variable
    # is not assigned a specific logp term, and it also does not make sense to separate
    # them internally.
    dummy_logps = (at.constant([], name="dummy_marginalized_logp"),) * (len(values) - 1)
    return logp, *dummy_logps
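
The dispatcher above implements plain finite marginalization: each `logp_op` call evaluates the joint log-density with the marginalized RV fixed at one domain value, and the logsumexp combines them,

    log p(y) = logsumexp over k in domain(x) of [ log p(x = k) + log p(y | x = k) ]

so the cost of the marginal logp grows with the size of the domain, one joint-logp evaluation per domain value.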
