From af027ad113d5b6ea7133929c75830afe74d713a7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 9 Mar 2024 21:34:15 +0000 Subject: [PATCH] Implement specialized transformed logp dispatch --- pymc/distributions/multivariate.py | 14 ++++-- pymc/distributions/transforms.py | 11 ++--- pymc/initial_point.py | 2 +- pymc/logprob/abstract.py | 56 +++++++++++++++++++++++- pymc/logprob/basic.py | 2 +- pymc/logprob/transform_value.py | 45 +++++++++++++++---- pymc/logprob/transforms.py | 36 +-------------- pymc/logprob/utils.py | 6 +-- pymc/model/fgraph.py | 2 +- pymc/model/transform/conditioning.py | 2 +- tests/distributions/test_multivariate.py | 15 +++++++ tests/distributions/test_transform.py | 2 +- tests/logprob/test_transforms.py | 2 +- tests/test_util.py | 2 +- 14 files changed, 133 insertions(+), 64 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 956bca276d..3b7bed5112 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -69,7 +69,7 @@ to_tuple, ) from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform -from pymc.logprob.abstract import _logprob +from pymc.logprob.abstract import _logprob, _transformed_logprob from pymc.math import kron_diag, kron_dot from pymc.pytensorf import intX from pymc.util import check_dist_not_registered @@ -2818,8 +2818,7 @@ def zerosumnormal_support_point(op, rv, *rv_inputs): @_default_transform.register(ZeroSumNormalRV) def zerosum_default_transform(op, rv): - n_zerosum_axes = tuple(np.arange(-op.ndim_supp, 0)) - return ZeroSumTransform(n_zerosum_axes) + return ZeroSumTransform(n_zerosum_axes=op.ndim_supp) @_logprob.register(ZeroSumNormalRV) @@ -2845,3 +2844,12 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): ) return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0") + + +@_transformed_logprob.register(ZeroSumNormalRV, ZeroSumTransform) +def transformed_zerosumnormal_logp(op, transform, unconstrained_value, rv_inputs): + _, sigma, _ = rv_inputs + zerosum_axes = transform.zerosum_axes + if len(zerosum_axes) != op.ndim_supp: + raise NotImplementedError + return pm.logp(Normal.dist(0, sigma), unconstrained_value).sum(zerosum_axes) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index aeedceedd3..3db7960cf3 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -25,6 +25,7 @@ from pytensor.graph import Op from pytensor.tensor import TensorVariable +from pymc.logprob.abstract import Transform from pymc.logprob.transforms import ( ChainedTransform, CircularTransform, @@ -32,11 +33,9 @@ LogOddsTransform, LogTransform, SimplexTransform, - Transform, ) __all__ = [ - "Transform", "simplex", "logodds", "Interval", @@ -277,8 +276,10 @@ class ZeroSumTransform(Transform): __props__ = ("zerosum_axes",) - def __init__(self, zerosum_axes): - self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) + def __init__(self, n_zerosum_axes: int): + if not n_zerosum_axes > 0: + raise ValueError("Transform is only valid for n_zerosum_axes > 0") + self.zerosum_axes = tuple(range(-n_zerosum_axes, 0)) @staticmethod def extend_axis(array, axis): @@ -314,7 +315,7 @@ def backward(self, value, *rv_inputs): return value def log_jac_det(self, value, *rv_inputs): - return pt.constant(0.0) + return pt.zeros(value.shape[: -len(self.zerosum_axes)]) log_exp_m1 = LogExpM1() diff --git a/pymc/initial_point.py b/pymc/initial_point.py index f9d7855bbc..c839a1be7d 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -25,7 +25,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.tensor.variable import TensorVariable -from pymc.logprob.transforms import Transform +from pymc.logprob.abstract import Transform from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index f35ab4c523..bae4508f7b 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -38,10 +38,15 @@ from collections.abc import Sequence from functools import singledispatch +from typing import Union +import multipledispatch +import pytensor.tensor as pt + +from pytensor.gradient import jacobian from pytensor.graph.op import Op from pytensor.graph.utils import MetaType -from pytensor.tensor import TensorVariable +from pytensor.tensor import TensorVariable, Variable from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable @@ -153,3 +158,52 @@ def __init__(self, scalar_op, *args, **kwargs): MeasurableVariable.register(MeasurableElemwise) + + +class Transform(abc.ABC): + ndim_supp = None + + @abc.abstractmethod + def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: + """Apply the transformation.""" + + @abc.abstractmethod + def backward( + self, value: TensorVariable, *inputs: Variable + ) -> Union[TensorVariable, tuple[TensorVariable, ...]]: + """Invert the transformation. Multiple values may be returned when the + transformation is not 1-to-1""" + + def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: + """Construct the log of the absolute value of the Jacobian determinant.""" + if self.ndim_supp not in (0, 1): + raise NotImplementedError( + f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}" + ) + if self.ndim_supp == 0: + jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape) + return pt.log(pt.abs(jac)) + else: + phi_inv = self.backward(value, *inputs) + return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) + + def __str__(self): + return f"{self.__class__.__name__}" + + +@multipledispatch.dispatch(Op, Transform) +def _transformed_logprob( + op: Op, + transform: Transform, + unconstrained_value: TensorVariable, + rv_inputs: Sequence[TensorVariable], +): + """Create a graph for the log-density/mass of a transformed ``RandomVariable``. + + This function dispatches on the type of ``op``, which should be a subclass + of ``RandomVariable`` and ``transform``, which should be a subclass of ``Transform``. + + """ + raise NotImplementedError( + f"Transformed logprob method not implemented for {op} with transform {transform}" + ) diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 446ef59355..6494109ec9 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -58,6 +58,7 @@ from pymc.logprob.abstract import ( MeasurableVariable, + Transform, _icdf_helper, _logcdf_helper, _logprob, @@ -65,7 +66,6 @@ ) from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph from pymc.logprob.transform_value import TransformValuesRewrite -from pymc.logprob.transforms import Transform from pymc.logprob.utils import rvs_in_graph from pymc.pytensorf import replace_vars_in_graphs diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 966d4b069a..48fe5d3585 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -23,12 +23,12 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.scan.op import Scan from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableVariable, Transform, _logprob, _transformed_logprob from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db -from pymc.logprob.transforms import Transform class TransformedValue(Op): @@ -97,7 +97,26 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) This is introduced by the `TransformValuesRewrite` """ rv_op = rv_outs[0].owner.op + transforms = op.transforms rv_inputs = rv_outs[0].owner.inputs + + if use_jacobian and len(values) == 1 and len(transforms) == 1: + # Check if there's a specialized transform logp implemented + [value] = values + assert isinstance(value.owner.op, TransformedValue) + unconstrained_value = value.owner.inputs[1] + [transform] = transforms + try: + return _transformed_logprob( + rv_op, + transform, + unconstrained_value=unconstrained_value, + rv_inputs=rv_inputs, + **kwargs, + ) + except NotImplementedError: + pass + logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs) if not isinstance(logprobs, Sequence): @@ -112,8 +131,8 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) continue assert isinstance(value.owner.op, TransformedValue) - original_forward_value = value.owner.inputs[1] - log_jac_det = transform.log_jac_det(original_forward_value, *rv_inputs).copy() + unconstrained_value = value.owner.inputs[1] + log_jac_det = transform.log_jac_det(unconstrained_value, *rv_inputs).copy() # The jacobian determinant has less dims than the logp # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions. # In this case we have to reduce the last logp dimensions, as they are no longer independent @@ -299,6 +318,17 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A return transformed_rv_node.outputs +transform_values_rewrites_db = SequenceDB() +transform_values_rewrites_db.name = "transform_values_rewrites_db" + +transform_values_rewrites_db.register( + "transform_values", in2out(transform_values, ignore_newtrees=True), "basic" +) +transform_values_rewrites_db.register( + "transform_scan_values", in2out(transform_scan_values, ignore_newtrees=True), "basic" +) + + class TransformValuesMapping(Feature): r"""A `Feature` that maintains a map between value variables and their transforms.""" @@ -315,9 +345,6 @@ def on_attach(self, fgraph): class TransformValuesRewrite(GraphRewriter): r"""Transforms value variables according to a map.""" - transform_rewrite = in2out(transform_values, ignore_newtrees=True) - scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True) - def __init__( self, values_to_transforms: dict[TensorVariable, Union[Transform, None]], @@ -340,8 +367,8 @@ def add_requirements(self, fgraph): fgraph.attach_feature(values_transforms_feature) def apply(self, fgraph: FunctionGraph): - self.transform_rewrite.rewrite(fgraph) - self.scan_transform_rewrite.rewrite(fgraph) + query = RewriteDatabaseQuery(include=["basic"]) + transform_values_rewrites_db.query(query).rewrite(fgraph) @node_rewriter([TransformedValue]) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3702a97550..dcd4e84684 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -33,15 +33,13 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import abc -from typing import Callable, Optional, Union +from typing import Callable, Optional import numpy as np import pytensor.tensor as pt from pytensor import scan -from pytensor.gradient import jacobian from pytensor.graph.basic import Node, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter @@ -109,6 +107,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableVariable, + Transform, _icdf, _icdf_helper, _logcdf, @@ -124,37 +123,6 @@ ) -class Transform(abc.ABC): - ndim_supp = None - - @abc.abstractmethod - def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: - """Apply the transformation.""" - - @abc.abstractmethod - def backward( - self, value: TensorVariable, *inputs: Variable - ) -> Union[TensorVariable, tuple[TensorVariable, ...]]: - """Invert the transformation. Multiple values may be returned when the - transformation is not 1-to-1""" - - def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: - """Construct the log of the absolute value of the Jacobian determinant.""" - if self.ndim_supp not in (0, 1): - raise NotImplementedError( - f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}" - ) - if self.ndim_supp == 0: - jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape) - return pt.log(pt.abs(jac)) - else: - phi_inv = self.backward(value, *inputs) - return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) - - def __str__(self): - return f"{self.__class__.__name__}" - - class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 49827f7a61..a464f8f9a1 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -33,7 +33,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import typing import warnings from collections.abc import Container, Sequence @@ -56,13 +55,10 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableVariable, Transform, _logprob from pymc.pytensorf import replace_vars_in_graphs from pymc.util import makeiter -if typing.TYPE_CHECKING: - from pymc.logprob.transforms import Transform - def replace_rvs_by_values( graphs: Sequence[TensorVariable], diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index 48903c9b72..24dacf1e2b 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -23,7 +23,7 @@ from pytensor.scalar import Identity from pytensor.tensor.elemwise import Elemwise -from pymc.logprob.transforms import Transform +from pymc.logprob.abstract import Transform from pymc.model.core import Model from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py index b321007c68..b27c516bb3 100644 --- a/pymc/model/transform/conditioning.py +++ b/pymc/model/transform/conditioning.py @@ -20,7 +20,7 @@ from pytensor.tensor import TensorVariable from pymc import Model -from pymc.logprob.transforms import Transform +from pymc.logprob.abstract import Transform from pymc.logprob.utils import rvs_in_graph from pymc.model.fgraph import ( ModelDeterministic, diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 0541d8f497..b5b2c9998e 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1770,6 +1770,21 @@ def test_batched_sigma(self): sigma=batch_test_sigma[None, :, None], n_zerosum_axes=2, support_shape=(3, 2) ) + def test_transformed_logprob(self): + with pm.Model() as m: + x = pm.ZeroSumNormal("x", sigma=np.pi, shape=(5, 3), n_zerosum_axes=1) + pytensor.dprint(m.compile_logp().f) + + [transformed_logp] = m.logp(sum=False) + + unconstrained_value = m.rvs_to_values[x] + transform = m.rvs_to_transforms[x] + constrained_value = transform.backward(unconstrained_value) + reference_logp = pm.logp(x, constrained_value) + + test_dict = {unconstrained_value: pm.draw(transform.forward(x))} + np.testing.assert_allclose(transformed_logp.eval(test_dict), reference_logp.eval(test_dict)) + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, scale, rng): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index b0187a4ebe..4b58c4a9e3 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -24,8 +24,8 @@ import pymc as pm import pymc.distributions.transforms as tr +from pymc.logprob.abstract import Transform from pymc.logprob.basic import transformed_conditional_logp -from pymc.logprob.transforms import Transform from pymc.pytensorf import floatX, jacobian from pymc.testing import ( Circ, diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index acf7296f47..5bdad178f8 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -45,6 +45,7 @@ from pymc.distributions.continuous import Cauchy, ChiSquared from pymc.distributions.discrete import Bernoulli +from pymc.logprob.abstract import Transform from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform, @@ -61,7 +62,6 @@ ScaleTransform, SinhTransform, TanhTransform, - Transform, ) from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs diff --git a/tests/test_util.py b/tests/test_util.py index 61d916249e..7a8dbffbfe 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -22,7 +22,7 @@ import pymc as pm -from pymc.distributions.transforms import Transform +from pymc.logprob.abstract import Transform from pymc.util import ( UNSET, _get_seeds_per_chain,