42
42
import pytensor .tensor as at
43
43
44
44
from pytensor .gradient import DisconnectedType , jacobian
45
- from pytensor .graph .basic import Apply , Node , Variable
45
+ from pytensor .graph .basic import Apply , Node , Variable , clone_replace
46
46
from pytensor .graph .features import AlreadyThere , Feature
47
47
from pytensor .graph .fg import FunctionGraph
48
48
from pytensor .graph .op import Op
49
49
from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
50
50
from pytensor .scalar import Add , Exp , Log , Mul
51
+ from pytensor .scan .op import Scan
51
52
from pytensor .tensor .math import add , exp , log , mul
52
53
from pytensor .tensor .rewriting .basic import (
53
54
register_specialize ,
@@ -186,11 +187,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
186
187
return trans_node .outputs
187
188
188
189
190
@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    """Apply transforms to Scan value variables.

    This specialized rewrite is needed because Scan replaces the original value
    variables by a more complex graph. We want to apply the transform to the
    original value variable in this subgraph, leaving the rest intact.

    Parameters
    ----------
    fgraph
        The function graph being rewritten; must carry the
        ``preserve_rv_mappings`` and ``values_to_transforms`` features.
    node
        A ``Scan`` node whose outputs may be mapped to value variables.

    Returns
    -------
    The outputs of a cloned, transformed node, or ``None`` when the rewrite
    does not apply.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    # Both features are attached by TransformValuesRewrite; without them there
    # is nothing to do.
    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    # Collect the Scan outputs that have an associated value variable.
    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    # BUG FIX: the lookup key must be the comprehension variable `value_var`,
    # not the leaked loop variable `value` (which would make every RV use the
    # transform of the *last* collected value variable).
    transforms = [
        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
        for value_var in value_vars
    ]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instances of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)
        trans_node.outputs[rv_var_out_idx].name = rv_var.name

        if transform is None:
            continue

        # We access the original value variable and apply the transform to that
        original_value_var = rv_map_feature.original_values[value_var]
        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)

        # We then replace the reference to the original value variable in the scan
        # value variable by the back-transform projection computed above.

        # The first input corresponds to the original value variable. We are careful to
        # only clone_replace that part of the graph, as we don't want to break the
        # mappings between other rvs that are likely to be present in the rest of the
        # scan value variable graph.
        # TODO: Is it true that the original value only appears in the first input
        # and that no other RV can appear there?
        (trans_original_value_var,) = clone_replace(
            (value_var.owner.inputs[0],),
            replace={original_value_var: trans_original_value_var},
        )
        trans_value_var = value_var.owner.clone_with_new_inputs(
            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
        ).default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs
271
+
272
+
189
273
class TransformValuesMapping (Feature ):
190
274
r"""A `Feature` that maintains a map between value variables and their transforms."""
191
275
192
276
def __init__ (self , values_to_transforms ):
193
- self .values_to_transforms = values_to_transforms
277
+ self .values_to_transforms = values_to_transforms . copy ()
194
278
195
279
def on_attach (self , fgraph ):
196
280
if hasattr (fgraph , "values_to_transforms" ):
@@ -203,6 +287,7 @@ class TransformValuesRewrite(GraphRewriter):
203
287
r"""Transforms value variables according to a map."""
204
288
205
289
transform_rewrite = in2out (transform_values , ignore_newtrees = True )
290
+ scan_transform_rewrite = in2out (transform_scan_values , ignore_newtrees = True )
206
291
207
292
def __init__ (
208
293
self ,
@@ -226,7 +311,8 @@ def add_requirements(self, fgraph):
226
311
fgraph .attach_feature (values_transforms_feature )
227
312
228
313
def apply (self , fgraph : FunctionGraph ):
229
- return self .transform_rewrite .rewrite (fgraph )
314
+ self .transform_rewrite .rewrite (fgraph )
315
+ self .scan_transform_rewrite .rewrite (fgraph )
230
316
231
317
232
318
class MeasurableTransform (MeasurableElemwise ):
0 commit comments