Skip to content

Commit 93bfa1b

Browse files
committed
Fix bug in JAX cloning of RNG shared variables
1 parent 2b7f95c commit 93bfa1b

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

pytensor/link/jax/linker.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,14 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
4444
new_inp_storage = [new_inp.get_value(borrow=True)]
4545
storage_map[new_inp] = new_inp_storage
4646
old_inp_storage = storage_map.pop(old_inp)
47-
input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
47+
# Find index of old_inp_storage in input_storage
48+
for input_storage_idx, input_storage_item in enumerate(input_storage):
49+
# We have to establish equality based on identity because input_storage may contain numpy arrays
50+
if input_storage_item is old_inp_storage:
51+
break
52+
else: # no break
53+
raise ValueError()
54+
input_storage[input_storage_idx] = new_inp_storage
4855
fgraph.remove_input(
4956
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
5057
)

tests/link/jax/test_random.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,43 @@ def test_random_updates(rng_ctor):
6363
)
6464

6565

66+
def test_random_updates_input_storage_order():
67+
"""Test case described in issue #314.
68+
69+
This happened when we tried to update the input storage after we clone the shared RNG.
70+
We used to call `input_storage.index(old_input_storage)` which would fail when the input_storage contained
71+
numpy arrays before the RNG value, which would fail the equality check.
72+
73+
"""
74+
pt_rng = RandomStream(1)
75+
76+
batchshape = (3, 1, 4, 4)
77+
inp_shared = pytensor.shared(
78+
np.zeros(batchshape, dtype="float64"), name="inp_shared"
79+
)
80+
81+
inp = at.tensor4(dtype="float64", name="inp")
82+
inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
83+
84+
# This function replaces inp by input_shared in the update expression
85+
# This is what caused the RNG to appear later than inp_shared in the input_storage
86+
with pytest.warns(
87+
UserWarning,
88+
match=r"The RandomType SharedVariables \[.+\] will not be used",
89+
):
90+
fn = pytensor.function(
91+
inputs=[],
92+
outputs=[],
93+
updates={inp_shared: inp_update},
94+
givens={inp: inp_shared},
95+
mode="JAX",
96+
)
97+
fn()
98+
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
99+
fn()
100+
np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3)
101+
102+
66103
@pytest.mark.parametrize(
67104
"rv_op, dist_params, base_size, cdf_name, params_conv",
68105
[

0 commit comments

Comments
 (0)