Skip to content

Commit f1279fa

Browse files
committed
Allow transforms to work with multiple valued nodes
1 parent 5d2c0de commit f1279fa

File tree

2 files changed

+122
-32
lines changed

2 files changed

+122
-32
lines changed

pymc/logprob/transforms.py

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import abc
3838

3939
from copy import copy
40-
from typing import Callable, Dict, List, Optional, Tuple, Union
40+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
4141

4242
import pytensor.tensor as at
4343

@@ -133,43 +133,55 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
133133
``Y`` on the natural scale.
134134
"""
135135

136-
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
137-
values_to_transforms = getattr(fgraph, "values_to_transforms", None)
136+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
137+
values_to_transforms: Optional[TransformValuesMapping] = getattr(
138+
fgraph, "values_to_transforms", None
139+
)
138140

139141
if rv_map_feature is None or values_to_transforms is None:
140142
return None # pragma: no cover
141143

142-
try:
143-
rv_var = node.default_output()
144-
rv_var_out_idx = node.outputs.index(rv_var)
145-
except ValueError:
146-
return None
144+
rv_vars = []
145+
value_vars = []
147146

148-
value_var = rv_map_feature.rv_values.get(rv_var, None)
149-
if value_var is None:
147+
for out in node.outputs:
148+
value = rv_map_feature.rv_values.get(out, None)
149+
if value is None:
150+
continue
151+
rv_vars.append(out)
152+
value_vars.append(value)
153+
154+
if not value_vars:
150155
return None
151156

152-
transform = values_to_transforms.get(value_var, None)
157+
transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]
153158

154-
if transform is None:
159+
if all(transform is None for transform in transforms):
155160
return None
156161

157-
new_op = _create_transformed_rv_op(node.op, transform)
162+
new_op = _create_transformed_rv_op(node.op, transforms)
158163
# Create a new `Apply` node and outputs
159164
trans_node = node.clone()
160165
trans_node.op = new_op
161-
trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
162166

163167
# We now assume that the old value variable represents the *transformed space*.
164168
# This means that we need to replace all instance of the old value variable
165169
# with "inversely/un-" transformed versions of itself.
166-
new_value_var = transformed_variable(
167-
transform.backward(value_var, *trans_node.inputs), value_var
168-
)
169-
if value_var.name and getattr(transform, "name", None):
170-
new_value_var.name = f"{value_var.name}_{transform.name}"
170+
for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
171+
rv_var_out_idx = node.outputs.index(rv_var)
172+
trans_node.outputs[rv_var_out_idx].name = rv_var.name
171173

172-
rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
174+
if transform is None:
175+
continue
176+
177+
new_value_var = transformed_variable(
178+
transform.backward(value_var, *trans_node.inputs), value_var
179+
)
180+
181+
if value_var.name and getattr(transform, "name", None):
182+
new_value_var.name = f"{value_var.name}_{transform.name}"
183+
184+
rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
173185

174186
return trans_node.outputs
175187

@@ -549,7 +561,7 @@ def log_jac_det(self, value, *inputs):
549561

550562
def _create_transformed_rv_op(
551563
rv_op: Op,
552-
transform: RVTransform,
564+
transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
553565
*,
554566
cls_dict_extra: Optional[Dict] = None,
555567
) -> Op:
@@ -572,14 +584,20 @@ def _create_transformed_rv_op(
572584
573585
"""
574586

575-
trans_name = getattr(transform, "name", "transformed")
587+
if not isinstance(transforms, Sequence):
588+
transforms = (transforms,)
589+
590+
trans_names = [
591+
getattr(transform, "name", "transformed") if transform is not None else "None"
592+
for transform in transforms
593+
]
576594
rv_op_type = type(rv_op)
577595
rv_type_name = rv_op_type.__name__
578596
cls_dict = rv_op_type.__dict__.copy()
579597
rv_name = cls_dict.get("name", "")
580598
if rv_name:
581-
cls_dict["name"] = f"{rv_name}_{trans_name}"
582-
cls_dict["transform"] = transform
599+
cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
600+
cls_dict["transforms"] = transforms
583601

584602
if cls_dict_extra is not None:
585603
cls_dict.update(cls_dict_extra)
@@ -595,17 +613,27 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
595613
We assume that the value variable was back-transformed to be on the natural
596614
support of the respective random variable.
597615
"""
598-
(value,) = values
616+
logprobs = _logprob(rv_op, values, *inputs, **kwargs)
599617

600-
logprob = _logprob(rv_op, values, *inputs, **kwargs)
618+
if not isinstance(logprobs, Sequence):
619+
logprobs = [logprobs]
601620

602621
if use_jacobian:
603-
assert isinstance(value.owner.op, TransformedVariable)
604-
original_forward_value = value.owner.inputs[1]
605-
jacobian = op.transform.log_jac_det(original_forward_value, *inputs)
606-
logprob += jacobian
607-
608-
return logprob
622+
assert len(values) == len(logprobs) == len(op.transforms)
623+
logprobs_jac = []
624+
for value, transform, logprob in zip(values, op.transforms, logprobs):
625+
if transform is None:
626+
logprobs_jac.append(logprob)
627+
continue
628+
assert isinstance(value.owner.op, TransformedVariable)
629+
original_forward_value = value.owner.inputs[1]
630+
jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
631+
if value.name:
632+
jacobian.name = f"{value.name}_jacobian"
633+
logprobs_jac.append(logprob + jacobian)
634+
logprobs = logprobs_jac
635+
636+
return logprobs
609637

610638
new_op = copy(rv_op)
611639
new_op.__class__ = new_op_type

pymc/tests/logprob/test_transforms.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@
4242
import scipy.special
4343

4444
from numdifftools import Jacobian
45+
from pytensor.compile.builders import OpFromGraph
4546
from pytensor.graph.basic import equal_computations
4647
from pytensor.graph.fg import FunctionGraph
4748

4849
from pymc.distributions.transforms import _default_transform, log, logodds
50+
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
4951
from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
5052
from pymc.logprob.transforms import (
5153
ChainedTransform,
@@ -437,6 +439,66 @@ def test_default_transform_multiout():
437439
)
438440

439441

442+
@pytest.fixture(scope="module")
443+
def multiout_measurable_op():
444+
# Create a dummy Op that just returns the two inputs
445+
mu1, mu2 = at.scalars("mu1", "mu2")
446+
447+
class TestOpFromGraph(OpFromGraph):
448+
def do_constant_folding(self, fgraph, node):
449+
False
450+
451+
multiout_op = TestOpFromGraph([mu1, mu2], [mu1 + 0.0, mu2 + 0.0])
452+
453+
MeasurableVariable.register(TestOpFromGraph)
454+
455+
@_logprob.register(TestOpFromGraph)
456+
def logp_multiout(op, values, mu1, mu2):
457+
value1, value2 = values
458+
return value1 + mu1, value2 + mu2
459+
460+
@_get_measurable_outputs.register(TestOpFromGraph)
461+
def measurable_multiout_op_outputs(op, node):
462+
return node.outputs
463+
464+
return multiout_op
465+
466+
467+
@pytest.mark.parametrize("transform_x", (True, False))
468+
@pytest.mark.parametrize("transform_y", (True, False))
469+
def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measurable_op):
470+
x, y = multiout_measurable_op(1, 2)
471+
x.name = "x"
472+
y.name = "y"
473+
x_vv = x.clone()
474+
y_vv = y.clone()
475+
476+
transform_rewrite = TransformValuesRewrite(
477+
{
478+
x_vv: LogTransform() if transform_x else None,
479+
y_vv: ExpTransform() if transform_y else None,
480+
}
481+
)
482+
483+
logp = joint_logprob({x: x_vv, y: y_vv}, extra_rewrites=transform_rewrite)
484+
485+
x_vv_test = np.random.normal()
486+
y_vv_test = np.abs(np.random.normal())
487+
488+
expected_logp = 0
489+
if not transform_x:
490+
expected_logp += x_vv_test + 1
491+
else:
492+
expected_logp += np.exp(x_vv_test) + 1 + x_vv_test
493+
# y logp
494+
if not transform_y:
495+
expected_logp += y_vv_test + 2
496+
else:
497+
expected_logp += np.log(y_vv_test) + 2 - np.log(y_vv_test)
498+
499+
np.testing.assert_almost_equal(logp.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp)
500+
501+
440502
def test_TransformValuesMapping():
441503
x = at.vector()
442504
fg = FunctionGraph(outputs=[x])

0 commit comments

Comments
 (0)