@@ -63,6 +63,43 @@ def test_random_updates(rng_ctor):
63
63
)
64
64
65
65
66
+ def test_random_updates_input_storage_order ():
67
+ """Test case described in issue #314.
68
+
69
+ This happened when we tried to update the input storage after we clone the shared RNG.
70
+ We used to call `input_storage.index(old_input_storage)` which would fail when the input_storage contained
71
+ numpy arrays before the RNG value, which would fail the equality check.
72
+
73
+ """
74
+ pt_rng = RandomStream (1 )
75
+
76
+ batchshape = (3 , 1 , 4 , 4 )
77
+ inp_shared = pytensor .shared (
78
+ np .zeros (batchshape , dtype = "float64" ), name = "inp_shared"
79
+ )
80
+
81
+ inp = at .tensor4 (dtype = "float64" , name = "inp" )
82
+ inp_update = inp + pt_rng .normal (size = inp .shape , loc = 5 , scale = 1e-5 )
83
+
84
+ # This function replaces inp by input_shared in the update expression
85
+ # This is what caused the RNG to appear later than inp_shared in the input_storage
86
+ with pytest .warns (
87
+ UserWarning ,
88
+ match = r"The RandomType SharedVariables \[.+\] will not be used" ,
89
+ ):
90
+ fn = pytensor .function (
91
+ inputs = [],
92
+ outputs = [],
93
+ updates = {inp_shared : inp_update },
94
+ givens = {inp : inp_shared },
95
+ mode = "JAX" ,
96
+ )
97
+ fn ()
98
+ np .testing .assert_allclose (inp_shared .get_value (), 5 , rtol = 1e-3 )
99
+ fn ()
100
+ np .testing .assert_allclose (inp_shared .get_value (), 10 , rtol = 1e-3 )
101
+
102
+
66
103
@pytest .mark .parametrize (
67
104
"rv_op, dist_params, base_size, cdf_name, params_conv" ,
68
105
[
0 commit comments