diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py
index d24440dccf..32669db412 100644
--- a/pymc/logprob/scan.py
+++ b/pymc/logprob/scan.py
@@ -459,8 +459,6 @@ def find_measurable_scans(fgraph, node):
         # We're going to set those values on our `new_val_var` so that it can
         # serve as a complete replacement for the old input `outer_input_var`.
 
-        # from pytensor.graph import clone_replace
-        #
         new_val_var = outer_input_var.owner.clone_with_new_inputs(
             [new_val_var] + outer_input_var.owner.inputs[1:]
         ).default_output()
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 6de89f8314..0e89635af6 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -42,12 +42,13 @@
 import pytensor.tensor as at
 
 from pytensor.gradient import DisconnectedType, jacobian
-from pytensor.graph.basic import Apply, Node, Variable
+from pytensor.graph.basic import Apply, Node, Variable, clone_replace
 from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
 from pytensor.scalar import Add, Exp, Log, Mul
+from pytensor.scan.op import Scan
 from pytensor.tensor.math import add, exp, log, mul
 from pytensor.tensor.rewriting.basic import (
     register_specialize,
@@ -186,11 +187,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     return trans_node.outputs
 
 
+@node_rewriter(tracks=[Scan])
+def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
+    """Apply transforms to Scan value variables.
+
+    This specialized rewrite is needed because Scan replaces the original value variables
+    with a more complex graph. We want to apply the transform to the original value variable
+    in this subgraph, leaving the rest intact.
+    """
+
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    values_to_transforms: Optional[TransformValuesMapping] = getattr(
+        fgraph, "values_to_transforms", None
+    )
+
+    if rv_map_feature is None or values_to_transforms is None:
+        return None  # pragma: no cover
+
+    rv_vars = []
+    value_vars = []
+
+    for out in node.outputs:
+        value = rv_map_feature.rv_values.get(out, None)
+        if value is None:
+            continue
+        rv_vars.append(out)
+        value_vars.append(value)
+
+    if not value_vars:
+        return None
+
+    transforms = [
+        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
+        for value_var in value_vars
+    ]
+
+    if all(transform is None for transform in transforms):
+        return None
+
+    new_op = _create_transformed_rv_op(node.op, transforms)
+    trans_node = node.clone()
+    trans_node.op = new_op
+
+    # We now assume that the old value variable represents the *transformed space*.
+    # This means that we need to replace all instances of the old value variable
+    # with "inversely/un-" transformed versions of itself.
+    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
+        rv_var_out_idx = node.outputs.index(rv_var)
+        trans_node.outputs[rv_var_out_idx].name = rv_var.name
+
+        if transform is None:
+            continue
+
+        # We access the original value variable and apply the transform to that
+        original_value_var = rv_map_feature.original_values[value_var]
+        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)
+
+        # We then replace the reference to the original value variable in the scan value
+        # variable with the back-transformed projection computed above
+
+        # The first input corresponds to the original value variable. We are careful to
+        # only clone_replace that part of the graph, as we don't want to break the
+        # mappings between other RVs that are likely to be present in the rest of the
+        # scan value variable graph
+        # TODO: Is it true that the original value only appears in the first input
+        # and that no other RV can appear there?
+        (trans_original_value_var,) = clone_replace(
+            (value_var.owner.inputs[0],),
+            replace={original_value_var: trans_original_value_var},
+        )
+        trans_value_var = value_var.owner.clone_with_new_inputs(
+            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
+        ).default_output()
+
+        new_value_var = transformed_variable(trans_value_var, original_value_var)
+
+        if value_var.name and getattr(transform, "name", None):
+            new_value_var.name = f"{value_var.name}_{transform.name}"
+
+        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
+
+    return trans_node.outputs
+
+
 class TransformValuesMapping(Feature):
     r"""A `Feature` that maintains a map between value variables and their transforms."""
 
     def __init__(self, values_to_transforms):
-        self.values_to_transforms = values_to_transforms
+        self.values_to_transforms = values_to_transforms.copy()
 
     def on_attach(self, fgraph):
         if hasattr(fgraph, "values_to_transforms"):
@@ -203,6 +287,7 @@ class TransformValuesRewrite(GraphRewriter):
     r"""Transforms value variables according to a map."""
 
     transform_rewrite = in2out(transform_values, ignore_newtrees=True)
+    scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)
 
     def __init__(
         self,
@@ -226,7 +311,8 @@ def add_requirements(self, fgraph):
         fgraph.attach_feature(values_transforms_feature)
 
     def apply(self, fgraph: FunctionGraph):
-        return self.transform_rewrite.rewrite(fgraph)
+        self.transform_rewrite.rewrite(fgraph)
+        self.scan_transform_rewrite.rewrite(fgraph)
 
 
 class MeasurableTransform(MeasurableElemwise):
diff --git a/pymc/tests/logprob/test_transforms.py b/pymc/tests/logprob/test_transforms.py
index 474f8e4bb7..8467f3ddba 100644
--- a/pymc/tests/logprob/test_transforms.py
+++ b/pymc/tests/logprob/test_transforms.py
@@ -45,6 +45,7 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scan import scan
 
 from pymc.distributions.transforms import _default_transform, log, logodds
 from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
@@ -781,3 +782,56 @@ def test_invalid_broadcasted_transform_rv_fails():
     logp = joint_logprob({y_rv: y_vv})
     logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
     assert False, "Should have failed before"
+
+
+def test_scan_transform():
+    """Test that Scan-valued variables can be transformed"""
+
+    init = at.random.beta(1, 1, name="init")
+    init_vv = init.clone()
+
+    innov, _ = scan(
+        fn=lambda prev_innov: at.random.beta(prev_innov * 10, (1 - prev_innov) * 10),
+        outputs_info=[init],
+        n_steps=4,
+    )
+    innov.name = "innov"
+    innov_vv = innov.clone()
+
+    tr = TransformValuesRewrite(
+        {
+            init_vv: LogOddsTransform(),
+            innov_vv: LogOddsTransform(),
+        }
+    )
+    logp = factorized_joint_logprob(
+        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
+    )[innov_vv]
+    logp_fn = pytensor.function([init_vv, innov_vv], logp, on_unused_input="ignore")
+
+    # Create an unrolled scan graph as reference
+    innov = []
+    prev_innov = init
+    for i in range(4):
+        next_innov = at.random.beta(prev_innov * 10, (1 - prev_innov) * 10, name=f"innov[{i}]")
+        innov.append(next_innov)
+        prev_innov = next_innov
+    innov = at.stack(innov)
+    innov.name = "innov"
+
+    tr = TransformValuesRewrite(
+        {
+            init_vv: LogOddsTransform(),
+            innov_vv: LogOddsTransform(),
+        }
+    )
+    ref_logp = factorized_joint_logprob(
+        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
+    )[innov_vv]
+    ref_logp_fn = pytensor.function([init_vv, innov_vv], ref_logp, on_unused_input="ignore")
+
+    test_point = {
+        "init": np.array(-0.5),
+        "innov": np.full((4,), -0.5),
+    }
+    np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))
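For context, a minimal usage sketch of what the new scan_transform_rewrite enables, distilled from the test above. It assumes factorized_joint_logprob is importable from pymc.logprob.joint_logprob and LogOddsTransform/TransformValuesRewrite from pymc.logprob.transforms, as this changeset suggests; exact import paths may differ across releases.

    import numpy as np
    import pytensor
    import pytensor.tensor as at
    from pytensor.scan import scan

    from pymc.logprob.joint_logprob import factorized_joint_logprob
    from pymc.logprob.transforms import LogOddsTransform, TransformValuesRewrite

    # A Beta random walk: each step's Beta parameters depend on the previous draw.
    init = at.random.beta(1, 1, name="init")
    init_vv = init.clone()
    innov, _ = scan(
        fn=lambda prev: at.random.beta(prev * 10, (1 - prev) * 10),
        outputs_info=[init],
        n_steps=4,
    )
    innov.name = "innov"
    innov_vv = innov.clone()

    # With the Scan-specific rewrite in place, the log-odds transform is applied
    # to the Scan value variable as well, so both value inputs are taken to live
    # on the unconstrained (log-odds) scale.
    tr = TransformValuesRewrite({init_vv: LogOddsTransform(), innov_vv: LogOddsTransform()})
    logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    logp_fn = pytensor.function([init_vv, innov_vv], logp, on_unused_input="ignore")
    print(logp_fn(np.array(-0.5), np.full((4,), -0.5)))  # finite logp at unconstrained values

Note that TransformValuesRewrite.apply now runs both passes in sequence: the generic transform_rewrite first, then the Scan-specific scan_transform_rewrite for any value variables that belong to a Scan.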