Incorporate RV meta information in intermediate MeasurableVariables #6360

Open
@ricardoV94

Description

This pertains to the logprob submodule. Consider the logprob derivation of an expression like:

import numpy as np
import pymc as pm

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal

pm.logp(x, np.zeros((2, 5)))

Once we identify that the logprob can be derived as a simple censored pdf, we create a MeasurableClip that replaces x. However, this MeasurableClip does not retain any of the meta-information about the type of RV that it encapsulates (ndim_supp, dtype, support axis).

class MeasurableClip(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""

    valid_scalar_types = (Clip,)


measurable_clip = MeasurableClip(scalar_clip)
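
For reference, the censored density that the clip corresponds to can be written out by hand. The following is a minimal pure-Python sketch of the standard censored normal log-probability for a single scalar value. It is not PyMC's actual implementation, and the helper names (normal_logpdf, normal_logcdf, normal_logsf, censored_normal_logp) are made up for illustration.

```python
import math


def normal_logpdf(value, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) at `value`."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def normal_logcdf(value, mu=0.0, sigma=1.0):
    """log P(X <= value) for X ~ Normal(mu, sigma)."""
    z = (value - mu) / sigma
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def normal_logsf(value, mu=0.0, sigma=1.0):
    """log P(X > value) for X ~ Normal(mu, sigma)."""
    z = (value - mu) / sigma
    return math.log(0.5 * math.erfc(z / math.sqrt(2.0)))


def censored_normal_logp(value, mu, sigma, lower, upper):
    """Log-probability of clip(X, lower, upper) at `value`.

    Hypothetical helper: the mass below `lower` piles up at `lower`,
    the mass above `upper` piles up at `upper`, and values strictly
    inside the bounds keep the ordinary normal density.
    """
    if value < lower or value > upper:
        return -math.inf
    if value == lower:
        return normal_logcdf(lower, mu, sigma)
    if value == upper:
        return normal_logsf(upper, mu, sigma)
    return normal_logpdf(value, mu, sigma)


inside = censored_normal_logp(0.0, 0.0, 1.0, -1.0, 1.0)
at_lower = censored_normal_logp(-1.0, 0.0, 1.0, -1.0, 1.0)
at_upper = censored_normal_logp(1.0, 0.0, 1.0, -1.0, 1.0)
outside = censored_normal_logp(2.0, 0.0, 1.0, -1.0, 1.0)
```

The computation is purely elementwise and mixes point masses with a density, which is consistent with describing such a variable as having ndim_supp=0 and a "mixed" type.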

If we wanted to compose the graph further, we would run into problems whenever an operation needs this meta-information:

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal
x = x.T

pm.logp(x, np.zeros((5, 2)))  # NotImplementedError: PyMC could not infer logp of input variable.

This happens because, to infer the logprob of a transposed (dimshuffled) variable, we need to know the original support dimensionality and support axis (which is always the rightmost for pure distributions):

pymc/pymc/logprob/tensor.py

Lines 285 to 298 in a0d6ba0

# We can only apply this rewrite directly to `RandomVariable`s, as those are
# the only `Op`s for which we always know the support axis. Other measurable
# variables can have arbitrary support axes (e.g., if they contain separate
# `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s
# should still be supported as long as the `DimShuffle`s can be merged/
# lifted towards the base RandomVariable.
# TODO: If we include the support axis as meta information in each
# intermediate MeasurableVariable, we can lift this restriction.
if not (
    base_var.owner
    and isinstance(base_var.owner.op, RandomVariable)
    and base_var not in rv_map_feature.rv_values
):
    return None  # pragma: no cover

If we propagated that information to the MeasurableClip (ndim_supp=0, support_axis=None, dtype="mixed"), the DimShuffle rewrite could be applied safely and we could derive the logp for the second example. This is also useful for other rewrites...
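
To make the idea concrete, here is a rough sketch of how such meta-information could be carried around and updated by a dimshuffle. Everything in it is hypothetical: MeasurableMeta and dimshuffle_meta do not exist in PyMC, the field names simply mirror the ones mentioned above, and support axes are assumed to be stored as non-negative indices into the variable's dimensions.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class MeasurableMeta:
    """Hypothetical meta-information attached to a measurable variable."""

    ndim_supp: int
    support_axis: Optional[Tuple[int, ...]]  # None when ndim_supp == 0
    dtype: str  # e.g. "continuous", "discrete" or "mixed"


def dimshuffle_meta(meta, new_order):
    """Return the meta-information after a dimshuffle.

    `new_order` follows the DimShuffle convention: integers refer to
    input axes and "x" marks a newly inserted broadcastable axis.
    """
    if meta.support_axis is None:
        # Univariate support: there is no axis to track, so any
        # reordering of the dimensions leaves the meta unchanged.
        return meta
    new_axes = []
    for axis in meta.support_axis:
        if axis not in new_order:
            raise ValueError(f"support axis {axis} was dropped by the dimshuffle")
        # The output position of an input axis is its index in new_order.
        new_axes.append(new_order.index(axis))
    return replace(meta, support_axis=tuple(sorted(new_axes)))


# The clipped normal from the example: nothing changes under a transpose.
clip_meta = MeasurableMeta(ndim_supp=0, support_axis=None, dtype="mixed")
clip_transposed = dimshuffle_meta(clip_meta, (1, 0))

# A 2-D multivariate variable whose support axis is the rightmost one.
mv_meta = MeasurableMeta(ndim_supp=1, support_axis=(1,), dtype="continuous")
mv_transposed = dimshuffle_meta(mv_meta, (1, 0))
mv_expanded = dimshuffle_meta(mv_meta, ("x", 0, 1))
```

With something along these lines, a rewrite would no longer need to assume that the support axis is the rightmost one; it could read the position from the meta-information of whichever measurable variable it is given.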

More context in aesara-devs/aeppl#183
