Description
This pertains to the logprob submodule. During logprob derivation of an expression like:
import numpy as np
import pymc as pm
x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1) # Censored normal
pm.logp(x, np.zeros((2, 5)))
When we identify that the logprob can be derived as a simple censored pdf, we create a MeasurableClip that replaces x. However, this MeasurableClip does not retain any of the meta-information about the type of RV that it encapsulates (ndim_supp, dtype, support axis):
pymc/pymc/logprob/censoring.py
Lines 61 to 67 in a0d6ba0
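To make "a simple censored pdf" concrete, here is a minimal standard-library sketch of the density such a clipped normal has. It is an illustration only, not the code in censoring.py: the helper names (`normal_logpdf`, `normal_logcdf`, `normal_logsf`, `censored_normal_logp`) are made up for this example, and it assumes a continuous base distribution, so all the mass below the lower bound collapses onto the bound itself (and likewise above the upper bound).

```python
import math


def normal_logpdf(value, mu, sigma=1.0):
    """Log-density of Normal(mu, sigma) at `value`."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def normal_logcdf(value, mu, sigma=1.0):
    """log P(X <= value), written with erfc for better tail accuracy."""
    z = (value - mu) / sigma
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def normal_logsf(value, mu, sigma=1.0):
    """log P(X > value), the log survival function."""
    z = (value - mu) / sigma
    return math.log(0.5 * math.erfc(z / math.sqrt(2.0)))


def censored_normal_logp(value, mu, lower, upper, sigma=1.0):
    """Hypothetical helper: logp of clip(Normal(mu, sigma), lower, upper)."""
    if value < lower or value > upper:
        return -math.inf  # clipping makes these values impossible
    if value == lower:
        return normal_logcdf(lower, mu, sigma)  # point mass at the lower bound
    if value == upper:
        return normal_logsf(upper, mu, sigma)  # point mass at the upper bound
    return normal_logpdf(value, mu, sigma)  # untouched interior


# Means 0..4 as in the example above, all evaluated at 0 (strictly inside the bounds)
interior = [censored_normal_logp(0.0, float(mu), -1.0, 1.0) for mu in range(5)]
```

Because the result mixes a density (interior) with point masses (bounds), the variable is neither purely continuous nor purely discrete, which is one reason the measure type is worth tracking explicitly.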
If we wanted to compose the graph further, we would run into issues as soon as some operation needs this information:
x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1) # Censored normal
x = x.T
pm.logp(x, np.zeros((5, 2))) # NotImplementedError: PyMC could not infer logp of input variable.
This happens because to infer the logprob of a transposed (dimshuffled) variable, we need to know the original support dimensionality and support axis (which is always the rightmost for pure distributions):
Lines 285 to 298 in a0d6ba0
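The reason the support dimensionality matters can be shown without PyMC. When ndim_supp is 0, the logp is elementwise, so the logp of a transposed variable is just the transposed logp evaluated at the transposed value. The sketch below uses plain nested lists and a unit-variance normal; `logp_x`, `logp_x_T` and `transpose` are illustrative names, not PyMC functions, and the identity is only claimed for the univariate (ndim_supp=0) case.

```python
import math


def normal_logpdf(value, mu):
    """Log-density of Normal(mu, 1) at `value`."""
    return -0.5 * (value - mu) ** 2 - 0.5 * math.log(2.0 * math.pi)


def transpose(matrix):
    """Swap the two axes of a nested-list matrix."""
    return [list(row) for row in zip(*matrix)]


# Means with shape (2, 5): np.arange(5) broadcast over two rows
mu = [[float(j) for j in range(5)] for _ in range(2)]


def logp_x(value):
    """Elementwise logp of x (shape (2, 5)); valid because ndim_supp == 0."""
    return [
        [normal_logpdf(v, m) for v, m in zip(value_row, mu_row)]
        for value_row, mu_row in zip(value, mu)
    ]


def logp_x_T(value_T):
    """logp of x.T (shape (5, 2)): undo the transpose, score, redo it."""
    return transpose(logp_x(transpose(value_T)))


value_T = [[0.0, 0.0] for _ in range(5)]  # zeros with shape (5, 2)
result = logp_x_T(value_T)
```

For a multivariate variable the same shortcut would be wrong whenever the transpose moves a support axis, which is why the rewrite has to know ndim_supp and where the support axis sits.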
If we propagated that information to the MeasurableClip (ndim_supp=0, support_axis=None, dtype="mixed"), the Dimshuffle rewrite could be safely used and we could derive the logp for the second example. This would also be useful for other rewrites...
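As a rough sketch of what propagating this information could look like, the snippet below models the metadata as a small record and lets a rewrite query it. Everything here is hypothetical: `MeasurableMeta` and `can_dimshuffle` do not exist in PyMC, the field names simply mirror the ones proposed above, and the rule for multivariate variables (support axes must stay rightmost and in order) is a deliberately conservative assumption. Only pure axis permutations are covered, not added or dropped axes.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class MeasurableMeta:
    """Hypothetical metadata a measurable op could carry (not a PyMC class)."""

    ndim_supp: int  # number of support dimensions (0 for univariate)
    support_axis: Optional[Tuple[int, ...]]  # None when ndim_supp == 0
    dtype: str  # measure type, e.g. "continuous", "discrete" or "mixed"


def can_dimshuffle(meta: MeasurableMeta, new_order: Tuple[int, ...]) -> bool:
    """Decide whether permuting the axes to `new_order` keeps the logp derivable."""
    if meta.ndim_supp == 0:
        # Elementwise logp: any permutation of the axes is safe.
        return True
    ndim = len(new_order)
    rightmost = tuple(range(ndim - meta.ndim_supp, ndim))
    # Conservative assumption: support axes must remain rightmost, in order.
    return tuple(new_order[-meta.ndim_supp:]) == rightmost


# Metadata proposed above for the clipped normal
clip_meta = MeasurableMeta(ndim_supp=0, support_axis=None, dtype="mixed")
transpose_ok = can_dimshuffle(clip_meta, (1, 0))
```

With a record like this attached to the MeasurableClip, the transpose in the second example would pass the check, while a multivariate variable whose support axis is moved would still be rejected.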
More context in aesara-devs/aeppl#183