 import abc
 
 from copy import copy
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import pytensor.tensor as at
 
@@ -133,43 +133,55 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     ``Y`` on the natural scale.
     """
 
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-    values_to_transforms = getattr(fgraph, "values_to_transforms", None)
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    values_to_transforms: Optional[TransformValuesMapping] = getattr(
+        fgraph, "values_to_transforms", None
+    )
 
     if rv_map_feature is None or values_to_transforms is None:
         return None  # pragma: no cover
 
-    try:
-        rv_var = node.default_output()
-        rv_var_out_idx = node.outputs.index(rv_var)
-    except ValueError:
-        return None
+    rv_vars = []
+    value_vars = []
 
-    value_var = rv_map_feature.rv_values.get(rv_var, None)
-    if value_var is None:
+    for out in node.outputs:
+        value = rv_map_feature.rv_values.get(out, None)
+        if value is None:
+            continue
+        rv_vars.append(out)
+        value_vars.append(value)
+
+    if not value_vars:
         return None
 
-    transform = values_to_transforms.get(value_var, None)
+    transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]
 
-    if transform is None:
+    if all(transform is None for transform in transforms):
         return None
 
-    new_op = _create_transformed_rv_op(node.op, transform)
+    new_op = _create_transformed_rv_op(node.op, transforms)
     # Create a new `Apply` node and outputs
     trans_node = node.clone()
     trans_node.op = new_op
-    trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instance of the old value variable
     # with "inversely/un-" transformed versions of itself.
-    new_value_var = transformed_variable(
-        transform.backward(value_var, *trans_node.inputs), value_var
-    )
-    if value_var.name and getattr(transform, "name", None):
-        new_value_var.name = f"{value_var.name}_{transform.name}"
+    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
+        rv_var_out_idx = node.outputs.index(rv_var)
+        trans_node.outputs[rv_var_out_idx].name = rv_var.name
 
-    rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
+        if transform is None:
+            continue
+
+        new_value_var = transformed_variable(
+            transform.backward(value_var, *trans_node.inputs), value_var
+        )
+
+        if value_var.name and getattr(transform, "name", None):
+            new_value_var.name = f"{value_var.name}_{transform.name}"
+
+        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
 
     return trans_node.outputs
 
@@ -549,7 +561,7 @@ def log_jac_det(self, value, *inputs):
 
 def _create_transformed_rv_op(
     rv_op: Op,
-    transform: RVTransform,
+    transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
     *,
     cls_dict_extra: Optional[Dict] = None,
 ) -> Op:
@@ -572,14 +584,20 @@ def _create_transformed_rv_op(
 
     """
 
-    trans_name = getattr(transform, "name", "transformed")
+    if not isinstance(transforms, Sequence):
+        transforms = (transforms,)
+
+    trans_names = [
+        getattr(transform, "name", "transformed") if transform is not None else "None"
+        for transform in transforms
+    ]
     rv_op_type = type(rv_op)
     rv_type_name = rv_op_type.__name__
     cls_dict = rv_op_type.__dict__.copy()
     rv_name = cls_dict.get("name", "")
     if rv_name:
-        cls_dict["name"] = f"{rv_name}_{trans_name}"
-    cls_dict["transform"] = transform
+        cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
+    cls_dict["transforms"] = transforms
 
     if cls_dict_extra is not None:
         cls_dict.update(cls_dict_extra)
@@ -595,17 +613,27 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
         We assume that the value variable was back-transformed to be on the natural
         support of the respective random variable.
         """
-        (value,) = values
+        logprobs = _logprob(rv_op, values, *inputs, **kwargs)
 
-        logprob = _logprob(rv_op, values, *inputs, **kwargs)
+        if not isinstance(logprobs, Sequence):
+            logprobs = [logprobs]
 
         if use_jacobian:
-            assert isinstance(value.owner.op, TransformedVariable)
-            original_forward_value = value.owner.inputs[1]
-            jacobian = op.transform.log_jac_det(original_forward_value, *inputs)
-            logprob += jacobian
-
-        return logprob
+            assert len(values) == len(logprobs) == len(op.transforms)
+            logprobs_jac = []
+            for value, transform, logprob in zip(values, op.transforms, logprobs):
+                if transform is None:
+                    logprobs_jac.append(logprob)
+                    continue
+                assert isinstance(value.owner.op, TransformedVariable)
+                original_forward_value = value.owner.inputs[1]
+                jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
+                if value.name:
+                    jacobian.name = f"{value.name}_jacobian"
+                logprobs_jac.append(logprob + jacobian)
+            logprobs = logprobs_jac
+
+        return logprobs
 
     new_op = copy(rv_op)
     new_op.__class__ = new_op_type
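
For orientation, here is a minimal, library-free sketch of the pattern this diff introduces: each output of a multi-output random variable gets its own optional transform, a `None` entry means "leave that output untouched", and the log-Jacobian correction is accumulated only for the transformed outputs. The `LogTransform` class and `backward_and_logdet` helper below are illustrative stand-ins, not part of the library; the actual code operates on PyTensor graphs through `RVTransform`, `transformed_variable`, and `_logprob`.

import math
from typing import List, Optional, Sequence, Tuple


class LogTransform:
    """Illustrative transform: forward maps (0, inf) to R, backward maps R back."""

    name = "log"

    def forward(self, value: float) -> float:
        return math.log(value)

    def backward(self, value: float) -> float:
        return math.exp(value)

    def log_jac_det(self, value: float) -> float:
        # log|d backward / d value| = log(exp(value)) = value
        return value


def backward_and_logdet(
    values: Sequence[float],
    transforms: Sequence[Optional[LogTransform]],
) -> Tuple[List[float], float]:
    """Back-transform each value and sum the log-Jacobian terms,
    skipping outputs whose transform is None."""
    natural_values = []
    log_jac = 0.0
    for value, transform in zip(values, transforms):
        if transform is None:
            natural_values.append(value)
            continue
        natural_values.append(transform.backward(value))
        log_jac += transform.log_jac_det(value)
    return natural_values, log_jac


# Two outputs, only the first one is log-transformed.
natural, log_jac = backward_and_logdet([0.5, 3.0], [LogTransform(), None])
assert natural == [math.exp(0.5), 3.0] and log_jac == 0.5

Running the example back-transforms only the first value and adds a single Jacobian term, mirroring the `continue` branches added to `transform_values` and `transformed_logprob` above.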