Skip to content

Fix transformed Scan values #6343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,6 @@ def find_measurable_scans(fgraph, node):

# We're going to set those values on our `new_val_var` so that it can
# serve as a complete replacement for the old input `outer_input_var`.
# from pytensor.graph import clone_replace
#
new_val_var = outer_input_var.owner.clone_with_new_inputs(
[new_val_var] + outer_input_var.owner.inputs[1:]
).default_output()
Expand Down
92 changes: 89 additions & 3 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@
import pytensor.tensor as at

from pytensor.gradient import DisconnectedType, jacobian
from pytensor.graph.basic import Apply, Node, Variable
from pytensor.graph.basic import Apply, Node, Variable, clone_replace
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul
from pytensor.scan.op import Scan
from pytensor.tensor.math import add, exp, log, mul
from pytensor.tensor.rewriting.basic import (
register_specialize,
Expand Down Expand Up @@ -186,11 +187,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
return trans_node.outputs


@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    """Apply transforms to Scan value variables.

    This specialized rewrite is needed because Scan replaces the original value variables
    by a more complex graph. We want to apply the transform to the original value variable
    in this subgraph, leaving the rest intact.

    Parameters
    ----------
    fgraph
        Function graph that carries the ``preserve_rv_mappings`` and
        ``values_to_transforms`` features set up by `TransformValuesRewrite`.
    node
        A `Scan` node whose outputs may be valued random variables.

    Returns
    -------
    The outputs of a cloned Scan node whose valued outputs have been wrapped in a
    transformed-RV op, or ``None`` if nothing had to be transformed.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    # BUG FIX: the comprehension previously indexed ``original_values[value]``,
    # where ``value`` was a stale leftover from the loop above, so every RV
    # received the transform of the *last* value variable. Use the actual
    # comprehension variable instead.
    transforms = [
        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
        for value_var in value_vars
    ]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instance of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)
        trans_node.outputs[rv_var_out_idx].name = rv_var.name

        if transform is None:
            continue

        # We access the original value variable and apply the transform to that
        original_value_var = rv_map_feature.original_values[value_var]
        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)

        # We then replace the reference to the original value variable in the scan value
        # variable by the back-transform projection computed above

        # The first input corresponds to the original value variable. We are careful to
        # only clone_replace that part of the graph, as we don't want to break the
        # mappings between other rvs that are likely to be present in the rest of the
        # scan value variable graph
        # TODO: Is it true that the original value only appears in the first input
        # and that no other RV can appear there?
        (trans_original_value_var,) = clone_replace(
            (value_var.owner.inputs[0],),
            replace={original_value_var: trans_original_value_var},
        )
        trans_value_var = value_var.owner.clone_with_new_inputs(
            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
        ).default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs


class TransformValuesMapping(Feature):
r"""A `Feature` that maintains a map between value variables and their transforms."""

def __init__(self, values_to_transforms):
self.values_to_transforms = values_to_transforms
self.values_to_transforms = values_to_transforms.copy()

def on_attach(self, fgraph):
if hasattr(fgraph, "values_to_transforms"):
Expand All @@ -203,6 +287,7 @@ class TransformValuesRewrite(GraphRewriter):
r"""Transforms value variables according to a map."""

transform_rewrite = in2out(transform_values, ignore_newtrees=True)
scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)

def __init__(
self,
Expand All @@ -226,7 +311,8 @@ def add_requirements(self, fgraph):
fgraph.attach_feature(values_transforms_feature)

def apply(self, fgraph: FunctionGraph):
return self.transform_rewrite.rewrite(fgraph)
self.transform_rewrite.rewrite(fgraph)
self.scan_transform_rewrite.rewrite(fgraph)


class MeasurableTransform(MeasurableElemwise):
Expand Down
54 changes: 54 additions & 0 deletions pymc/tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import scan

from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
Expand Down Expand Up @@ -781,3 +782,56 @@ def test_invalid_broadcasted_transform_rv_fails():
logp = joint_logprob({y_rv: y_vv})
logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
assert False, "Should have failed before"


def test_scan_transform():
    """Test that Scan valued variables can be transformed.

    Builds a 4-step Beta random walk with `scan`, applies a log-odds transform to
    both the initial and the innovation value variables, and checks the resulting
    logp against an unrolled (loop-free) reference graph with the same transforms.
    """

    init = at.random.beta(1, 1, name="init")
    init_vv = init.clone()

    innov, _ = scan(
        fn=lambda prev_innov: at.random.beta(prev_innov * 10, (1 - prev_innov) * 10),
        outputs_info=[init],
        n_steps=4,
    )
    innov.name = "innov"
    innov_vv = innov.clone()

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    logp_fn = pytensor.function([init_vv, innov_vv], logp, on_unused_input="ignore")

    # Create an unrolled scan graph as reference
    innov = []
    prev_innov = init
    for i in range(4):
        # BUG FIX: the name was previously the literal f"innov[i]" (missing braces),
        # giving all four RVs the same name instead of innov[0]..innov[3].
        next_innov = at.random.beta(prev_innov * 10, (1 - prev_innov) * 10, name=f"innov[{i}]")
        innov.append(next_innov)
        prev_innov = next_innov
    innov = at.stack(innov)
    innov.name = "innov"

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    ref_logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    ref_logp_fn = pytensor.function([init_vv, innov_vv], ref_logp, on_unused_input="ignore")

    test_point = {
        "init": np.array(-0.5),
        "innov": np.full((4,), -0.5),
    }
    np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))