|
| 1 | +from typing import Dict, Sequence, Tuple, Union |
| 2 | + |
| 3 | +import aesara.tensor as at |
| 4 | +import numpy as np |
| 5 | +from aeppl import factorized_joint_logprob |
| 6 | +from aeppl.abstract import _get_measurable_outputs |
| 7 | +from aeppl.logprob import _logprob |
| 8 | +from aesara import clone_replace |
| 9 | +from aesara.compile import SharedVariable |
| 10 | +from aesara.compile.builders import OpFromGraph |
| 11 | +from aesara.graph import Constant, FunctionGraph, ancestors |
| 12 | +from aesara.tensor import TensorVariable |
| 13 | +from aesara.tensor.elemwise import Elemwise |
| 14 | +from aesara.tensor.random.op import RandomVariable |
| 15 | +from aesara.tensor.random.var import ( |
| 16 | + RandomGeneratorSharedVariable, |
| 17 | + RandomStateSharedVariable, |
| 18 | +) |
| 19 | +from pymc import SymbolicRandomVariable |
| 20 | +from pymc.aesaraf import constant_fold, inputvars |
| 21 | +from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform |
| 22 | +from pymc.distributions.distribution import _moment, moment |
| 23 | +from pymc.model import Model |
| 24 | + |
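| | +# A minimal usage sketch (assumes `import pymc as pm`; the toy model below is illustrative |
| | +# and not part of this module): |
| | +# |
| | +#     with MarginalModel() as m: |
| | +#         idx = pm.Bernoulli("idx", p=0.7) |
| | +#         y = pm.Normal("y", mu=pm.math.switch(idx, -1.0, 1.0), sigma=1.0) |
| | +#     m.marginalize([idx]) |
| | +#     # `idx` is now summed out of the model logp via logsumexp over its domain (0, 1) |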
| 25 | + |
| 26 | +class MarginalModel(Model): |
| 27 | + def __init__(self, *args, **kwargs): |
| 28 | + super().__init__(*args, **kwargs) |
| 29 | + if self.parent is not None: |
| 30 | + raise NotImplementedError("MarginalModel cannot be used inside another Model") |
| 31 | + else: |
| 32 | + self.marginalized_rvs_to_dependent_rvs = {} |
| 33 | + |
| 34 | + def logp(self, vars=None, **kwargs): |
| 35 | + if not kwargs.get("sum", True): |
| 36 | + # Check if dependent RVs were requested |
| 37 | + if vars is not None and not isinstance(vars, Sequence): |
| 38 | + vars = (vars,) |
| 39 | + if vars is None or ( |
| 40 | + {v for vs in self.marginalized_rvs_to_dependent_rvs.values() for v in vs} |
| 41 | + & {self.values_to_rvs.get(var, var) for var in vars} |
| 42 | + ): |
| 43 | + raise ValueError( |
| 44 | + "Cannot request elemwise logp (sum=False) for variables that depend on a marginalized RV" |
| 45 | + ) |
| 46 | + return super().logp(vars, **kwargs) |
| 47 | + |
| 48 | + def point_logps(self, *args, **kwargs): |
| 49 | + # TODO: Fix this |
| 50 | + return {} |
| 51 | + |
| 52 | + def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]): |
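| | +        """Marginalize the given finite discrete free RVs out of the model logp. |
| | + |
| | +        Each marginalized RV and its dependent RVs are replaced by the outputs of a |
| | +        FiniteDiscreteMarginalRV, whose logp sums over the marginalized RV's finite |
| | +        domain. Returns a dict mapping the replaced RVs to their replacements. |
| | +        """ |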
| 53 | + # TODO: this does not need to be a property of a Model |
| 54 | + if not isinstance(rvs_to_marginalize, Sequence): |
| 55 | + rvs_to_marginalize = (rvs_to_marginalize,) |
| 56 | + |
| 57 | + supported_dists = (Bernoulli, Categorical, DiscreteUniform) |
| 58 | + for rv_to_marginalize in rvs_to_marginalize: |
| 59 | + if rv_to_marginalize not in self.free_RVs: |
| 60 | + raise ValueError( |
| 61 | + f"Marginalized RV {rv_to_marginalize} is not a free RV in the model" |
| 62 | + ) |
| 63 | + if not isinstance(rv_to_marginalize.owner.op, supported_dists): |
| 64 | + raise NotImplementedError( |
| 65 | + f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. " |
| 66 | + f"Supported distribution include {supported_dists}" |
| 67 | + ) |
| 68 | + |
| 69 | + if self.deterministics: |
| 70 | + # TODO: This should be fine if deterministics do not depend on marginalized RVs |
| 71 | + raise NotImplementedError("Models with deterministics cannot be marginalized") |
| 72 | + |
| 73 | + if self.potentials: |
| 74 | + raise NotImplementedError("Models with potentials cannot be marginalized") |
| 75 | + |
| 76 | +        # Replace each RV to be marginalized (and its dependent RVs) with a marginalized subgraph |
| 77 | + fg = FunctionGraph(outputs=self.basic_RVs, clone=False) |
| 78 | + toposort = fg.toposort() |
| 79 | + replacements = {} |
| 80 | + new_marginalized_rv = None |
| 81 | + new_dependent_rvs = [] |
| 82 | + for rv_to_marginalize in sorted( |
| 83 | + rvs_to_marginalize, key=lambda rv: toposort.index(rv.owner) |
| 84 | + ): |
| 85 | + old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph( |
| 86 | + fg, rv_to_marginalize, self.rvs_to_values |
| 87 | + ) |
| 88 | + # Update old mappings |
| 89 | + for old_rv, new_rv in zip(old_rvs, new_rvs): |
| 90 | + replacements[old_rv] = new_rv |
| 91 | + |
| 92 | + value = self.rvs_to_values.pop(old_rv) |
| 93 | + self.named_vars.pop(old_rv.name) |
| 94 | + new_rv.name = old_rv.name |
| 95 | + |
| 96 | + if old_rv is rv_to_marginalize: |
| 97 | + self.free_RVs.remove(old_rv) |
| 98 | + self.values_to_rvs.pop(value) |
| 99 | + self.rvs_to_transforms.pop(old_rv) |
| 100 | + self.rvs_to_total_sizes.pop(old_rv) |
| 101 | + new_marginalized_rv = new_rv |
| 102 | + continue |
| 103 | + |
| 104 | + new_dependent_rvs.append(new_rv) |
| 105 | + if old_rv in self.free_RVs: |
| 106 | + index = self.free_RVs.index(old_rv) |
| 107 | + self.free_RVs.pop(index) |
| 108 | + self.free_RVs.insert(index, new_rv) |
| 109 | + self._initial_values[new_rv] = self._initial_values.pop(old_rv) |
| 110 | + else: |
| 111 | + index = self.observed_RVs.index(old_rv) |
| 112 | + self.observed_RVs.pop(index) |
| 113 | + self.observed_RVs.insert(index, new_rv) |
| 114 | + self.rvs_to_values[new_rv] = value |
| 115 | + self.named_vars[new_rv.name] = new_rv |
| 116 | + self.values_to_rvs[value] = new_rv |
| 117 | + self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv) |
| 118 | + # TODO: Automatic imputation RV does not seem to have total_size mapping |
| 119 | + self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None) |
| 120 | + |
| 121 | + self.marginalized_rvs_to_dependent_rvs[new_marginalized_rv] = new_dependent_rvs |
| 122 | + return replacements |
| 123 | + |
| 124 | + |
| 125 | +def _find_dependent_rvs(dependable_rv, all_rvs): |
| 126 | +    # Find RVs that depend on dependable_rv |
| 127 | + dependent_rvs = [] |
| 128 | + for rv in all_rvs: |
| 129 | + if rv is dependable_rv: |
| 130 | + continue |
| 131 | + blockers = [other_rv for other_rv in all_rvs if other_rv is not rv] |
| 132 | + if dependable_rv in ancestors([rv], blockers=blockers): |
| 133 | + dependent_rvs.append(rv) |
| 134 | + return dependent_rvs |
| 135 | + |
| 136 | + |
| 137 | +def _find_input_rvs(output_rvs, all_rvs): |
| 138 | + blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] |
| 139 | + return [ |
| 140 | + var |
| 141 | + for var in ancestors(output_rvs, blockers=blockers) |
| 142 | + if var in blockers |
| 143 | + or (var.owner is None and not isinstance(var, (Constant, SharedVariable))) |
| 144 | + ] |
| 145 | + |
| 146 | + |
| 147 | +def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): |
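| | +    # Check that the graph connecting rv_to_marginalize to output_rvs consists only of |
| | +    # Elemwise operations, so that each batch dimension can be marginalized independently |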
| 148 | + # TODO: No need to consider apply nodes outside the subgraph... |
| 149 | + fg = FunctionGraph(outputs=output_rvs, clone=False) |
| 150 | + |
| 151 | + non_elemwise_blockers = [ |
| 152 | + o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs |
| 153 | + ] |
| 154 | + blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers |
| 155 | + blockers = [var for var in blocker_candidates if var not in output_rvs] |
| 156 | + |
| 157 | + # TODO: We could actually use these truncated inputs to |
| 158 | + # generate a smaller Marginalized graph... |
| 159 | + truncated_inputs = [ |
| 160 | + var |
| 161 | + for var in ancestors(output_rvs, blockers=blockers) |
| 162 | + if ( |
| 163 | + var in blockers |
| 164 | + or (var.owner is None and not isinstance(var, (Constant, SharedVariable))) |
| 165 | + ) |
| 166 | + ] |
| 167 | + |
| 168 | + # Check that we reach the marginalized rv following a pure elemwise graph |
| 169 | + if rv_to_marginalize not in truncated_inputs: |
| 170 | + return False |
| 171 | + |
| 172 | + # Check that none of the truncated inputs depends on the marginalized_rv |
| 173 | + other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] |
| 174 | + # TODO: We don't need to go all the way to the root variables |
| 175 | + if rv_to_marginalize in ancestors( |
| 176 | + other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] |
| 177 | + ): |
| 178 | + return False |
| 179 | + return True |
| 180 | + |
| 181 | + |
| 182 | +SUPPORTED_RNG_TYPES = (RandomStateSharedVariable, RandomGeneratorSharedVariable) |
| 183 | + |
| 184 | + |
| 185 | +class FiniteDiscreteMarginalRV(SymbolicRandomVariable): |
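| | +    """Symbolic RV that wraps a marginalized RV together with its dependent RVs. |
| | + |
| | +    The first `n_updates` inputs are the RNG variables of the inner RVs and the first |
| | +    `n_updates` outputs are their update expressions, followed by the marginalized RV |
| | +    and then its dependent RVs. |
| | +    """ |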
| 186 | + def __init__(self, *args, n_updates: int, **kwargs): |
| 187 | + self.n_updates = n_updates |
| 188 | + super().__init__(*args, **kwargs) |
| 189 | + |
| 190 | + def update(self, node): |
| 191 | + n_updates = node.op.n_updates |
| 192 | + shared_rng_inputs = node.inputs[:n_updates] |
| 193 | + update_outputs = node.outputs[:n_updates] |
| 194 | + assert len(update_outputs) == len(shared_rng_inputs) |
| 195 | + # We made sure to pass RNG inputs and output updates in the same order |
| 196 | + return {inp: out for inp, out in zip(shared_rng_inputs, update_outputs)} |
| 197 | + |
| 198 | + |
| 199 | +def _collect_updates(rvs: Sequence[TensorVariable]) -> Dict[TensorVariable, TensorVariable]: |
| 200 | + rng_updates = {} |
| 201 | + for rv in rvs: |
| 202 | + if isinstance(rv.owner.op, RandomVariable): |
| 203 | + rng = rv.owner.inputs[0] |
| 204 | + assert not hasattr(rng, "default_update") |
| 205 | + rng_updates[rng] = rv.owner.outputs[0] |
| 206 | + elif isinstance(rv.owner.op, SymbolicRandomVariable): |
| 207 | +            rng_updates.update(rv.owner.op.update(rv.owner)) |
| 208 | + else: |
| 209 | + raise TypeError(f"Unknown RV type: {rv.owner.op}") |
| 210 | + assert all(isinstance(rng, SUPPORTED_RNG_TYPES) for rng in rng_updates.keys()) |
| 211 | + return rng_updates |
| 212 | + |
| 213 | + |
| 214 | +def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values): |
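| | +    """Replace `rv_to_marginalize` and its dependent RVs in `fgraph` with the outputs |
| | +    of a single FiniteDiscreteMarginalRV and return the replaced and replacement RVs.""" |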
| 215 | + # TODO: This should eventually be integrated in a more general routine that can |
| 216 | + # identify other types of supported marginalization, of which finite discrete |
| 217 | + # RVs is just one |
| 218 | + |
| 219 | + dependent_rvs = _find_dependent_rvs(rv_to_marginalize, rvs_to_values) |
| 220 | + if not dependent_rvs: |
| 221 | + raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") |
| 222 | + |
| 223 | + marginalized_rv_input_rvs = _find_input_rvs([rv_to_marginalize], rvs_to_values) |
| 224 | + dependent_rvs_input_rvs = [ |
| 225 | + rv for rv in _find_input_rvs(dependent_rvs, rvs_to_values) if rv is not rv_to_marginalize |
| 226 | + ] |
| 227 | + |
| 228 | +    # If the marginalized RV has batched dimensions, check that the graph between the |
| 229 | +    # marginalized RV and the dependent RVs is composed strictly of Elemwise operations. |
| 230 | +    # This implies (?) that the dimensions are completely independent and a logp graph can |
| 231 | +    # ultimately be generated that is proportional to the support domain and not to the dimensions of the variable. |
| 232 | +    # We don't need to worry about batched graphs if the RV is scalar. |
| 233 | + # TODO: This eval is a bit hackish |
| 234 | + if np.prod(rv_to_marginalize.shape.eval()) > 1: |
| 235 | + if not _is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): |
| 236 | + raise NotImplementedError( |
| 237 | + "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " |
| 238 | + "This is currently not supported", |
| 239 | + ) |
| 240 | + |
| 241 | + input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs] |
| 242 | + rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] |
| 243 | + |
| 244 | + # Collect update expressions of the inner RVs. |
| 245 | + # Note: This could be avoided if we inlined the MarginalOp Graph before collecting |
| 246 | + # the updates in `pymc.aesaraf.compile_pymc` |
| 247 | + updates_rvs_to_marginalize = _collect_updates(rvs_to_marginalize) |
| 248 | + n_updates = len(updates_rvs_to_marginalize) |
| 249 | + assert n_updates |
| 250 | + |
| 251 | + outputs = list(updates_rvs_to_marginalize.values()) + rvs_to_marginalize |
| 252 | + # Clone replace inner RV rng inputs so that we can be sure of the update order |
| 253 | + replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()} |
| 254 | +    # Clone replace outer RV inputs, so that their shared RNGs don't make it into |
| 255 | + # the inner graph of the marginalized RVs |
| 256 | + replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs}) |
| 257 | + cloned_outputs = clone_replace(outputs, replace=replace_inputs) |
| 258 | + |
| 259 | + marginalization_op = FiniteDiscreteMarginalRV( |
| 260 | + inputs=list(replace_inputs.values()), |
| 261 | + outputs=cloned_outputs, |
| 262 | + ndim_supp=-1, # This will certainly break stuff :D |
| 263 | + n_updates=n_updates, |
| 264 | + ) |
| 265 | + marginalized_rvs = marginalization_op(*replace_inputs.keys())[n_updates:] |
| 266 | + fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) |
| 267 | + return rvs_to_marginalize, marginalized_rvs |
| 268 | + |
| 269 | + |
| 270 | +@_get_measurable_outputs.register(FiniteDiscreteMarginalRV) |
| 271 | +def _get_measurable_outputs_finite_discrete_marginal_rv(op, node): |
| 272 | + # The Marginalized RV (first non-update output) is not measurable, nor are updates |
| 273 | + return node.outputs[op.n_updates + 1 :] |
| 274 | + |
| 275 | + |
| 276 | +@_moment.register(FiniteDiscreteMarginalRV) |
| 277 | +def moment_finite_discrete_marginal_rv(op, rv, *rv_inputs): |
| 278 | + # Recreate inner RV and retrieve its moment |
| 279 | + node = rv.owner |
| 280 | + marginalized_rv, *dependent_rvs = clone_replace( |
| 281 | + op.inner_outputs[op.n_updates :], |
| 282 | + replace={u: v for u, v in zip(op.inner_inputs, rv_inputs)}, |
| 283 | + ) |
| 284 | + rv_idx = node.outputs[op.n_updates + 1 :].index(rv) |
| 285 | + rv = dependent_rvs[rv_idx] |
| 286 | + |
| 287 | + moment_marginalized_rv = moment(marginalized_rv) |
| 288 | + (rv,) = clone_replace([rv], replace={marginalized_rv: moment_marginalized_rv}) |
| 289 | + return moment(rv) |
| 290 | + |
| 291 | + |
| 292 | +def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: |
| 293 | + op = rv.owner.op |
| 294 | + if isinstance(op, Bernoulli): |
| 295 | + return (0, 1) |
| 296 | + elif isinstance(op, Categorical): |
| 297 | + p_param = rv.owner.inputs[3] |
| 298 | + return tuple(range(at.get_vector_length(p_param))) |
| 299 | + elif isinstance(op, DiscreteUniform): |
| 300 | + lower, upper = constant_fold(rv.owner.inputs[3:]) |
| 301 | + return tuple(range(lower, upper + 1)) |
| 302 | + |
| 303 | + raise NotImplementedError(f"Cannot compute domain for op {op}") |
| 304 | + |
| 305 | + |
| 306 | +@_logprob.register(FiniteDiscreteMarginalRV) |
| 307 | +def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): |
| 308 | + |
| 309 | + marginalized_rvs_node = op.make_node(*inputs) |
| 310 | + marginalized_rv, *dependent_rvs = clone_replace( |
| 311 | + op.inner_outputs[op.n_updates :], |
| 312 | + replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, |
| 313 | + ) |
| 314 | + |
| 315 | + # Some inputs are not root inputs (such as transformed projections of value variables) |
| 316 | + # Or cannot be used as inputs to an OpFromGraph (shared variables and constants) |
| 317 | + inputs = list(inputvars(inputs)) |
| 318 | + |
| 319 | + rvs_to_values = {} |
| 320 | + dummy_marginalized_value = marginalized_rv.clone() |
| 321 | + rvs_to_values[marginalized_rv] = dummy_marginalized_value |
| 322 | + rvs_to_values.update(zip(dependent_rvs, values)) |
| 323 | + _logp = at.sum( |
| 324 | + [ |
| 325 | + at.sum(factor) |
| 326 | + for factor in factorized_joint_logprob(rv_values=rvs_to_values, **kwargs).values() |
| 327 | + ] |
| 328 | + ) |
| 329 | + # OpFromGraph does not accept constant inputs... |
| 330 | + _values = [ |
| 331 | + value |
| 332 | + for value in rvs_to_values.values() |
| 333 | + if not isinstance(value, (Constant, SharedVariable)) |
| 334 | + ] |
| 335 | + # TODO: If we inline the logp graph, optimization becomes incredibly painful for |
| 336 | + # large domains... Would be great to find a way to vectorize the graph across |
| 337 | + # the domain values (when possible) |
| 338 | + logp_op = OpFromGraph([*_values, *inputs], [_logp], inline=False) |
| 339 | + |
| 340 | + # PyMC does not allow RVs in the logp graph... Even if we are just using the shape |
| 341 | +    # TODO: Get better work-around than .eval(). It probably makes sense to do a constant |
| 342 | + # fold pass in the final logp graph, so that individual logp functions don't have |
| 343 | + # to worry about it |
| 344 | + marginalized_rv_shape = marginalized_rv.shape.eval() |
| 345 | + non_const_values = [ |
| 346 | + value for value in values if not isinstance(value, (Constant, SharedVariable)) |
| 347 | + ] |
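| | +    # Marginalize by evaluating the joint logp at each value k in the finite domain of the |
| | +    # marginalized RV and combining the terms: logp(values) = logsumexp_k logp(k, values) |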
| 348 | + logp = at.logsumexp( |
| 349 | + [ |
| 350 | + logp_op( |
| 351 | + np.full(marginalized_rv_shape, marginalized_rv_const), *non_const_values, *inputs |
| 352 | + ) |
| 353 | + for marginalized_rv_const in _get_domain_of_finite_discrete_rv(marginalized_rv) |
| 354 | + ] |
| 355 | + ) |
| 356 | + # In the case of multiple dependent values, the whole logp is assigned just to the |
| 357 | +    # first value. This is quite hackish, but Aeppl errors out if some value variable |
| 358 | + # is not assigned a specific logp term, and it also does not make sense to separate |
| 359 | + # them internally. |
| 360 | + dummy_logps = (at.constant([], name="dummy_marginalized_logp"),) * (len(values) - 1) |
| 361 | + return logp, *dummy_logps |