Skip to content

Commit 43cce2e

Browse files
committed
Implement Scan value transforms
1 parent 176b2dc commit 43cce2e

File tree

3 files changed

+143
-5
lines changed

3 files changed

+143
-5
lines changed

pymc/logprob/scan.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,6 @@ def find_measurable_scans(fgraph, node):
459459

460460
# We're going to set those values on our `new_val_var` so that it can
461461
# serve as a complete replacement for the old input `outer_input_var`.
462-
# from pytensor.graph import clone_replace
463-
#
464462
new_val_var = outer_input_var.owner.clone_with_new_inputs(
465463
[new_val_var] + outer_input_var.owner.inputs[1:]
466464
).default_output()

pymc/logprob/transforms.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,13 @@
4242
import pytensor.tensor as at
4343

4444
from pytensor.gradient import DisconnectedType, jacobian
45-
from pytensor.graph.basic import Apply, Node, Variable
45+
from pytensor.graph.basic import Apply, Node, Variable, clone_replace
4646
from pytensor.graph.features import AlreadyThere, Feature
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.graph.op import Op
4949
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
5050
from pytensor.scalar import Add, Exp, Log, Mul
51+
from pytensor.scan.op import Scan
5152
from pytensor.tensor.math import add, exp, log, mul
5253
from pytensor.tensor.rewriting.basic import (
5354
register_specialize,
@@ -186,11 +187,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
186187
return trans_node.outputs
187188

188189

190+
@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    """Apply transforms to Scan value variables.

    This specialized rewrite is needed because Scan replaces the original value variables
    by a more complex graph. We want to apply the transform to the original value variable
    in this subgraph, leaving the rest intact.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    # BUG FIX: the original comprehension indexed `original_values[value]`, where
    # `value` was the stale loop variable left over from the loop above — every
    # entry therefore looked up the transform of the *last* value variable. Use
    # the comprehension's own `value_var` so each output gets its own transform.
    transforms = [
        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
        for value_var in value_vars
    ]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instances of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)
        trans_node.outputs[rv_var_out_idx].name = rv_var.name

        if transform is None:
            continue

        # We access the original value variable and apply the transform to that
        original_value_var = rv_map_feature.original_values[value_var]
        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)

        # We then replace the reference to the original value variable in the scan value
        # variable by the back-transform projection computed above

        # The first input corresponds to the original value variable. We are careful to
        # only clone_replace that part of the graph, as we don't want to break the
        # mappings between other rvs that are likely to be present in the rest of the
        # scan value variable graph
        # TODO: Is it true that the original value only appears in the first input
        # and that no other RV can appear there?
        (trans_original_value_var,) = clone_replace(
            (value_var.owner.inputs[0],),
            replace={original_value_var: trans_original_value_var},
        )
        trans_value_var = value_var.owner.clone_with_new_inputs(
            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
        ).default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs
271+
272+
189273
class TransformValuesMapping(Feature):
190274
r"""A `Feature` that maintains a map between value variables and their transforms."""
191275

192276
def __init__(self, values_to_transforms):
193-
self.values_to_transforms = values_to_transforms
277+
self.values_to_transforms = values_to_transforms.copy()
194278

195279
def on_attach(self, fgraph):
196280
if hasattr(fgraph, "values_to_transforms"):
@@ -203,6 +287,7 @@ class TransformValuesRewrite(GraphRewriter):
203287
r"""Transforms value variables according to a map."""
204288

205289
transform_rewrite = in2out(transform_values, ignore_newtrees=True)
290+
scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)
206291

207292
def __init__(
208293
self,
@@ -226,7 +311,8 @@ def add_requirements(self, fgraph):
226311
fgraph.attach_feature(values_transforms_feature)
227312

228313
def apply(self, fgraph: FunctionGraph):
229-
return self.transform_rewrite.rewrite(fgraph)
314+
self.transform_rewrite.rewrite(fgraph)
315+
self.scan_transform_rewrite.rewrite(fgraph)
230316

231317

232318
class MeasurableTransform(MeasurableElemwise):

pymc/tests/logprob/test_transforms.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from pytensor.compile.builders import OpFromGraph
4646
from pytensor.graph.basic import equal_computations
4747
from pytensor.graph.fg import FunctionGraph
48+
from pytensor.scan import scan
4849

4950
from pymc.distributions.transforms import _default_transform, log, logodds
5051
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
@@ -781,3 +782,56 @@ def test_invalid_broadcasted_transform_rv_fails():
781782
logp = joint_logprob({y_rv: y_vv})
782783
logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
783784
assert False, "Should have failed before"
785+
786+
787+
def test_scan_transform():
    """Test that Scan valued variables can be transformed.

    Builds a 4-step Beta random walk with `scan`, applies a log-odds transform
    to both the initial and innovation value variables, and checks that the
    resulting logp matches a manually unrolled reference graph.
    """

    init = at.random.beta(1, 1, name="init")
    init_vv = init.clone()

    innov, _ = scan(
        fn=lambda prev_innov: at.random.beta(prev_innov * 10, (1 - prev_innov) * 10),
        outputs_info=[init],
        n_steps=4,
    )
    innov.name = "innov"
    innov_vv = innov.clone()

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    logp_fn = pytensor.function([init_vv, innov_vv], logp, on_unused_input="ignore")

    # Create an unrolled scan graph as reference
    innov = []
    prev_innov = init
    for i in range(4):
        # BUG FIX: the original used f"innov[i]" (no braces), so every unrolled
        # RV got the same literal name "innov[i]"; interpolate the index.
        next_innov = at.random.beta(prev_innov * 10, (1 - prev_innov) * 10, name=f"innov[{i}]")
        innov.append(next_innov)
        prev_innov = next_innov
    innov = at.stack(innov)
    innov.name = "innov"

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    ref_logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    ref_logp_fn = pytensor.function([init_vv, innov_vv], ref_logp, on_unused_input="ignore")

    test_point = {
        "init": np.array(-0.5),
        "innov": np.full((4,), -0.5),
    }
    np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))

0 commit comments

Comments
 (0)