Skip to content

Commit bfbc8cc

Browse files
committed
Handle Scan in collect_default_updates
This allows proper seeding in CustomDists with Scans
1 parent 0073639 commit bfbc8cc

File tree

4 files changed

+152
-22
lines changed

4 files changed

+152
-22
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
596596
def update(self, node: Node):
597597
op = node.op
598598
inner_updates = collect_default_updates(
599-
op.inner_inputs, op.inner_outputs, must_be_shared=False
599+
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
600600
)
601601

602602
# Map inner updates to outer inputs/outputs
@@ -668,7 +668,7 @@ def rv_op(
668668
):
669669
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
670670
dummy_params = [dummy_size_param] + dummy_dist_params
671-
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
671+
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
672672

673673
rv_type = type(
674674
class_name,
@@ -713,7 +713,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
713713
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
714714
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
715715
dummy_params = [dummy_size_param] + dummy_dist_params
716-
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
716+
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
717717
new_rv_op = rv_type(
718718
inputs=dummy_params,
719719
outputs=[*dummy_updates_dict.values(), dummy_rv],

pymc/pytensorf.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.graph.op import Op
4949
from pytensor.scalar.basic import Cast
50+
from pytensor.scan.op import Scan
5051
from pytensor.tensor.basic import _as_tensor_variable
5152
from pytensor.tensor.elemwise import Elemwise
5253
from pytensor.tensor.random.op import RandomVariable
@@ -1004,16 +1005,49 @@ def reseed_rngs(
10041005

10051006

10061007
def collect_default_updates(
1007-
inputs: Sequence[Variable],
10081008
outputs: Sequence[Variable],
1009+
*,
1010+
inputs: Optional[Sequence[Variable]] = None,
10091011
must_be_shared: bool = True,
10101012
) -> Dict[Variable, Variable]:
10111013
"""Collect default update expressions for shared-variable RNGs used by RVs between inputs and outputs.
10121014
1013-
If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs.
1014-
This can be useful to obtain the symbolic update expressions from inner graphs.
1015-
"""
1015+
Parameters
1016+
----------
1017+
outputs: list of PyTensor variables
1018+
List of variables in whose graphs default updates will be collected.
1019+
inputs: list of PyTensor variables, optional
1020+
Input nodes above which default updates should not be collected.
1021+
When not provided, search will include top level inputs (roots).
1022+
must_be_shared: bool, default True
1023+
Used internally by PyMC. Whether updates should be collected only for shared RNG variables.
1024+
If False, updates are also collected for non-shared RNG inputs. This is used to collect update expressions for inner graphs.
1025+
1026+
Examples
1027+
--------
1028+
.. code:: python
1029+
import pymc as pm
1030+
from pytensor.scan import scan
1031+
from pymc.pytensorf import collect_default_updates
1032+
1033+
def scan_step(xtm1):
1034+
x = xtm1 + pm.Normal.dist()
1035+
x_update = collect_default_updates([x])
1036+
return x, x_update
10161037
1038+
x0 = pm.Normal.dist()
1039+
1040+
xs, updates = scan(
1041+
fn=scan_step,
1042+
outputs_info=[x0],
1043+
n_steps=10,
1044+
)
1045+
1046+
# PyMC makes use of the updates to seed xs properly.
1047+
# Without updates, it would raise an error.
1048+
xs_draws = pm.draw(xs, draws=10)
1049+
1050+
"""
10171051
# Avoid circular import
10181052
from pymc.distributions.distribution import SymbolicRandomVariable
10191053

@@ -1048,16 +1082,31 @@ def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
10481082
next_rng = client.op.update(client).get(rng)
10491083
if next_rng is None:
10501084
raise ValueError(
1051-
f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
1085+
f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
1086+
)
1087+
elif isinstance(client.op, Scan):
1088+
# Check if any shared output corresponds to the RNG
1089+
rng_idx = client.inputs.index(rng)
1090+
io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
1091+
out_idx = io_map.get(rng_idx, -1)
1092+
if out_idx != -1:
1093+
next_rng = client.outputs[out_idx]
1094+
else: # No break
1095+
raise ValueError(
1096+
f"No update found for at least one RNG used in Scan Op {client.op}.\n"
1097+
"You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
10521098
)
10531099
else:
1054-
# We don't know how this RNG should be updated (e.g., Scan).
1100+
# We don't know how this RNG should be updated (e.g., OpFromGraph).
10551101
# The user should provide an update manually
10561102
return None
10571103

10581104
# Recurse until we find final update for RNG
10591105
return find_default_update(clients, next_rng)
10601106

1107+
if inputs is None:
1108+
inputs = []
1109+
10611110
outputs = makeiter(outputs)
10621111
fg = FunctionGraph(outputs=outputs, clone=False)
10631112
clients = fg.clients
@@ -1129,7 +1178,7 @@ def compile_pymc(
11291178
"""
11301179
# Create an update mapping of RandomVariable's RNG so that it is automatically
11311180
# updated after every function call
1132-
rng_updates = collect_default_updates(inputs, outputs)
1181+
rng_updates = collect_default_updates(inputs=inputs, outputs=outputs)
11331182

11341183
# We always reseed random variables as this provides RNGs with no chances of collision
11351184
if rng_updates:

tests/distributions/test_distribution.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import pytest
2424
import scipy.stats as st
2525

26+
from pytensor import scan
2627
from pytensor.tensor import TensorVariable
2728

2829
import pymc as pm
@@ -51,6 +52,7 @@
5152
from pymc.logprob.abstract import get_measurable_outputs
5253
from pymc.logprob.basic import logcdf, logp
5354
from pymc.model import Deterministic, Model
55+
from pymc.pytensorf import collect_default_updates
5456
from pymc.sampling import draw, sample
5557
from pymc.testing import (
5658
BaseTestDistributionRandom,
@@ -523,6 +525,48 @@ def old_random(size):
523525
# New API is fine
524526
pm.CustomDist.dist(dist=old_random, class_name="custom_dist")
525527

528+
def test_scan(self):
529+
def trw(nu, sigma, steps, size):
530+
def step(xtm1, nu, sigma):
531+
x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
532+
return x, collect_default_updates([x])
533+
534+
xs, _ = scan(
535+
fn=step,
536+
outputs_info=pt.zeros(size),
537+
non_sequences=[nu, sigma],
538+
n_steps=steps,
539+
)
540+
541+
# Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
542+
# xs = swapaxes(xs, 0, -1)
543+
544+
return xs
545+
546+
nu = 4
547+
sigma = 0.7
548+
steps = 99
549+
batch_size = 3
550+
x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)
551+
552+
x_draw = pm.draw(x, random_seed=1)
553+
assert x_draw.shape == (steps, batch_size)
554+
np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw)
555+
assert not np.any(pm.draw(x, random_seed=2) == x_draw)
556+
557+
ref_dist = pm.RandomWalk.dist(
558+
init_dist=pm.Flat.dist(),
559+
innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
560+
steps=steps,
561+
size=(batch_size,),
562+
)
563+
ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T
564+
565+
np.testing.assert_allclose(
566+
pm.logp(x, x_draw).eval().sum(0),
567+
pm.logp(ref_dist, ref_val).eval(),
568+
)
569+
526570

527571
class TestSymbolicRandomVariable:
528572
def test_inline(self):

tests/test_pytensorf.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import pytest
2525
import scipy.sparse as sps
2626

27-
from pytensor import shared
27+
from pytensor import scan, shared
2828
from pytensor.compile.builders import OpFromGraph
2929
from pytensor.graph.basic import Variable, equal_computations
3030
from pytensor.tensor.random.basic import normal, uniform
@@ -465,7 +465,7 @@ def update(self, node):
465465
],
466466
)(rng1, rng2)
467467
with pytest.raises(
468-
ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
468+
ValueError, match="No update found for at least one RNG used in SymbolicRandomVariable"
469469
):
470470
compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])
471471

@@ -531,7 +531,7 @@ def test_nested_updates(self):
531531
next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
532532
next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs
533533

534-
collect_default_updates([], [x, y, z]) == {rng: next_rng3}
534+
collect_default_updates(inputs=[], outputs=[x, y, z]) == {rng: next_rng3}
535535

536536
fn = compile_pymc([], [x, y, z], random_seed=514)
537537
assert not set(list(np.array(fn()))) & set(list(np.array(fn())))
@@ -540,19 +540,56 @@ def test_nested_updates(self):
540540
fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
541541
assert set(list(np.array(fn()))) & set(list(np.array(fn())))
542542

543+
def test_collect_default_updates_must_be_shared(self):
544+
shared_rng = pytensor.shared(np.random.default_rng())
545+
nonshared_rng = shared_rng.type()
543546

544-
def test_collect_default_updates_must_be_shared():
545-
shared_rng = pytensor.shared(np.random.default_rng())
546-
nonshared_rng = shared_rng.type()
547+
next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
548+
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs
547549

548-
next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
549-
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs
550+
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
551+
assert res == {shared_rng: next_rng_of_shared}
550552

551-
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
552-
assert res == {shared_rng: next_rng_of_shared}
553+
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
554+
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}
553555

554-
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
555-
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}
556+
def test_scan_updates(self):
557+
def step_with_update(x, rng):
558+
next_rng, x = pm.Normal.dist(x, rng=rng).owner.outputs
559+
return x, {rng: next_rng}
560+
561+
def step_wo_update(x, rng):
562+
return step_with_update(x, rng)[0]
563+
564+
rng = pytensor.shared(np.random.default_rng())
565+
566+
xs, next_rng = scan(
567+
fn=step_wo_update,
568+
outputs_info=[pt.zeros(())],
569+
non_sequences=[rng],
570+
n_steps=10,
571+
name="test_scan",
572+
)
573+
574+
assert not next_rng
575+
576+
with pytest.raises(
577+
ValueError,
578+
match=r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}",
579+
):
580+
collect_default_updates([xs])
581+
582+
ys, next_rng = scan(
583+
fn=step_with_update,
584+
outputs_info=[pt.zeros(())],
585+
non_sequences=[rng],
586+
n_steps=10,
587+
)
588+
589+
assert collect_default_updates([ys]) == {rng: tuple(next_rng.values())[0]}
590+
591+
fn = compile_pymc([], ys, random_seed=1)
592+
assert not (set(fn()) & set(fn()))
556593

557594

558595
def test_replace_rng_nodes():

0 commit comments

Comments
 (0)