Skip to content

Commit 0cc81b9

Browse files
committed
Convert GaussianRandomWalk into a SymbolicRandomVariable
* Implements distribution agnostic univariate RandomWalk
1 parent 4eeac62 commit 0cc81b9

File tree

5 files changed

+211
-199
lines changed

5 files changed

+211
-199
lines changed

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
GaussianRandomWalk,
110110
MvGaussianRandomWalk,
111111
MvStudentTRandomWalk,
112+
RandomWalk,
112113
)
113114

114115
__all__ = [
@@ -173,6 +174,7 @@
173174
"LKJCholeskyCov",
174175
"LKJCorr",
175176
"AsymmetricLaplace",
177+
"RandomWalk",
176178
"GaussianRandomWalk",
177179
"MvGaussianRandomWalk",
178180
"MvStudentTRandomWalk",

pymc/distributions/logprob.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,20 @@
1717
from typing import Dict, List, Optional, Sequence, Union
1818

1919
import aesara
20-
import aesara.tensor as at
2120
import numpy as np
2221

23-
from aeppl import factorized_joint_logprob
22+
from aeppl import factorized_joint_logprob, logprob
2423
from aeppl.abstract import assign_custom_measurable_outputs
24+
from aeppl.logprob import _logprob
2525
from aeppl.logprob import logcdf as logcdf_aeppl
2626
from aeppl.logprob import logprob as logp_aeppl
27+
from aeppl.tensor import MeasurableJoin
2728
from aeppl.transforms import TransformValuesRewrite
29+
from aesara import tensor as at
30+
from aesara.graph import FunctionGraph, rewrite_graph
2831
from aesara.graph.basic import graph_inputs, io_toposort
2932
from aesara.tensor.random.op import RandomVariable
33+
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
3034
from aesara.tensor.subtensor import (
3135
AdvancedIncSubtensor,
3236
AdvancedIncSubtensor1,
@@ -320,3 +324,51 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
320324
return rv
321325
new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
322326
return new_node.outputs[node.outputs.index(rv)]
327+
328+
329+
@_logprob.register(MeasurableJoin)
def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
    """Compute the log-likelihood graph for a `Join`.

    This overrides the implementation in Aeppl, to constant fold the shapes
    of the base vars so that RandomVariables do not show up in the logp graph,
    which is a requirement enforced by `pymc.distributions.logprob.joint_logp`

    Parameters
    ----------
    op : MeasurableJoin
        The measurable join operation being given a logprob.
    values : sequence
        A single-element sequence holding the value variable of the join.
    axis
        Axis along which the base variables were joined.
    base_vars
        The measurable variables that were joined.

    Returns
    -------
    TensorVariable
        The log-probability graph of the joined variable, concatenated
        from the per-component log-probabilities.

    Raises
    ------
    ValueError
        If the component logps do not all have the same number of
        dimensions (e.g. a mix of univariate and multivariate inputs).
    """
    (value,) = values

    # Symbolic size of each component along the join axis.
    base_var_shapes = [base_var.shape[axis] for base_var in base_vars]

    # Constant-fold those shapes so no RandomVariable remains in the
    # resulting logp graph (required by joint_logp).
    shape_fg = FunctionGraph(
        outputs=base_var_shapes,
        features=[ShapeFeature()],
        clone=True,
    )
    base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs

    # Undo the join: slice the value variable back into one piece per component.
    split_values = at.split(
        value,
        splits_size=list(base_var_shapes),
        n_splits=len(base_vars),
        axis=axis,
    )

    logps = [
        logprob(base_var, split_value)
        for base_var, split_value in zip(base_vars, split_values)
    ]

    if len({logp.ndim for logp in logps}) != 1:
        raise ValueError(
            "Joined logps have different number of dimensions, this can happen when "
            "joining univariate and multivariate distributions",
        )

    # Number of support dimensions consumed by the logp; the concatenation
    # axis must be shifted left by that amount.
    base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim

    # Reuse the logps computed above rather than rebuilding each logprob
    # graph a second time (the original called `logprob` twice per component).
    join_logprob = at.concatenate(
        [at.atleast_1d(logp) for logp in logps],
        axis=axis - base_vars_ndim_supp,
    )

    return join_logprob

0 commit comments

Comments
 (0)