 17 |  17 | from typing import Dict, List, Optional, Sequence, Union
 18 |  18 |
 19 |  19 | import aesara
 20 |     | -import aesara.tensor as at
 21 |  20 | import numpy as np
 22 |  21 |
 23 |     | -from aeppl import factorized_joint_logprob
    |  22 | +from aeppl import factorized_joint_logprob, logprob
 24 |  23 | from aeppl.abstract import assign_custom_measurable_outputs
    |  24 | +from aeppl.logprob import _logprob
 25 |  25 | from aeppl.logprob import logcdf as logcdf_aeppl
 26 |  26 | from aeppl.logprob import logprob as logp_aeppl
    |  27 | +from aeppl.tensor import MeasurableJoin
 27 |  28 | from aeppl.transforms import TransformValuesRewrite
    |  29 | +from aesara import tensor as at
    |  30 | +from aesara.graph import FunctionGraph, rewrite_graph
 28 |  31 | from aesara.graph.basic import graph_inputs, io_toposort
 29 |  32 | from aesara.tensor.random.op import RandomVariable
    |  33 | +from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
 30 |  34 | from aesara.tensor.subtensor import (
 31 |  35 |     AdvancedIncSubtensor,
 32 |  36 |     AdvancedIncSubtensor1,
@@ -320,3 +324,51 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
320 | 324 |         return rv
321 | 325 |     new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
322 | 326 |     return new_node.outputs[node.outputs.index(rv)]
    | 327 | +
    | 328 | +
    | 329 | +@_logprob.register(MeasurableJoin)
    | 330 | +def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
    | 331 | +    """Compute the log-likelihood graph for a `Join`.
    | 332 | +
    | 333 | +    This overrides the implementation in Aeppl, to constant fold the shapes
    | 334 | +    of the base vars so that RandomVariables do not show up in the logp graph,
    | 335 | +    which is a requirement enforced by `pymc.distributions.logprob.joint_logp`
    | 336 | +    """
    | 337 | +    (value,) = values
    | 338 | +
    | 339 | +    base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
    | 340 | +
    | 341 | +    shape_fg = FunctionGraph(
    | 342 | +        outputs=base_var_shapes,
    | 343 | +        features=[ShapeFeature()],
    | 344 | +        clone=True,
    | 345 | +    )
    | 346 | +    base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs
    | 347 | +
    | 348 | +    split_values = at.split(
    | 349 | +        value,
    | 350 | +        splits_size=[base_var_shape for base_var_shape in base_var_shapes],
    | 351 | +        n_splits=len(base_vars),
    | 352 | +        axis=axis,
    | 353 | +    )
    | 354 | +
    | 355 | +    logps = [
    | 356 | +        logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values)
    | 357 | +    ]
    | 358 | +
    | 359 | +    if len({logp.ndim for logp in logps}) != 1:
    | 360 | +        raise ValueError(
    | 361 | +            "Joined logps have different number of dimensions, this can happen when "
    | 362 | +            "joining univariate and multivariate distributions",
    | 363 | +        )
    | 364 | +
    | 365 | +    base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
    | 366 | +    join_logprob = at.concatenate(
    | 367 | +        [
    | 368 | +            at.atleast_1d(logprob(base_var, split_value))
    | 369 | +            for base_var, split_value in zip(base_vars, split_values)
    | 370 | +        ],
    | 371 | +        axis=axis - base_vars_ndim_supp,
    | 372 | +    )
    | 373 | +
    | 374 | +    return join_logprob
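
For context (not part of the diff): a minimal sketch of how the `MeasurableJoin` logprob registered above would be exercised. It assumes aeppl's `factorized_joint_logprob` takes a mapping from random variables to value variables and that the override is in effect once this PyMC module has been imported; the variable names are illustrative only.

import aesara.tensor as at
from aeppl import factorized_joint_logprob

# Two univariate normal RVs joined along axis 0; aeppl's rewrites turn the
# Join into a MeasurableJoin, which dispatches to the logprob defined above.
x = at.random.normal(0, 1, size=2, name="x")
y = at.random.normal(1, 1, size=3, name="y")
xy = at.join(0, x, y)

xy_value = at.vector("xy_value")

# The override splits `xy_value` at the constant-folded sizes (2 and 3),
# evaluates the logp of each piece under `x` and `y`, and concatenates the
# results, so no RandomVariable shapes remain in the returned logp graph.
logp_terms = factorized_joint_logprob({xy: xy_value})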