@@ -2179,6 +2179,72 @@ def step(s, xtm2, xtm1, z):
2179
2179
assert gg .eval ({seq : [1 , 1 ], x0 : [1 , 1 ], z : 2 }) == 12
2180
2180
assert gg .eval ({seq : [1 , 1 ], x0 : [1 , 1 ], z : 1 }) == 3 / 2
2181
2181
2182
+ @pytest .mark .parametrize ("case" , ("inside-explicit" , "inside-implicit" , "outside" ))
2183
+ def test_non_shaped_input_disconnected_gradient (self , case ):
2184
+ """Test that Scan gradient works when non shaped variables are disconnected from the gradient.
2185
+
2186
+ Regression test for https://github.com/pymc-devs/pytensor/issues/6
2187
+ """
2188
+
2189
+ # In all cases rng is disconnected from the output gradient
2190
+ # Note that when it is an input to the scan (explicit or not) it is still not updated by the scan,
2191
+ # so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan.
2192
+ rng = shared (np .random .default_rng ())
2193
+
2194
+ data = pt .zeros (16 )
2195
+
2196
+ nonlocal_random_index = pt .random .integers (16 , rng = rng )
2197
+ nonlocal_random_datum = data [nonlocal_random_index ]
2198
+
2199
+ if case == "outside" :
2200
+
2201
+ def step (s , random_datum ):
2202
+ return (random_datum + s ) ** 2
2203
+
2204
+ strict = True
2205
+ non_sequences = [nonlocal_random_datum ]
2206
+
2207
+ elif case == "inside-implicit" :
2208
+
2209
+ def step (s ):
2210
+ return (nonlocal_random_datum + s ) ** 2
2211
+
2212
+ strict = False
2213
+ non_sequences = [] # Scan will introduce the non_sequences for us
2214
+
2215
+ elif case == "inside-explicit" :
2216
+
2217
+ def step (s , data , rng ):
2218
+ random_index = pt .random .integers (
2219
+ 16 , rng = rng
2220
+ ) # Not updated by the scan
2221
+ random_datum = data [random_index ]
2222
+ return (random_datum + s ) ** 2
2223
+
2224
+ strict = (True ,)
2225
+ non_sequences = [data , rng ]
2226
+
2227
+ else :
2228
+ raise ValueError (f"Invalid case: { case } " )
2229
+
2230
+ seq = vector ("seq" )
2231
+ xs , _ = scan (
2232
+ step ,
2233
+ sequences = [seq ],
2234
+ non_sequences = non_sequences ,
2235
+ strict = strict ,
2236
+ )
2237
+ x0 = xs [0 ]
2238
+
2239
+ np .testing .assert_allclose (
2240
+ x0 .eval ({seq : [np .pi , np .nan , np .nan ]}),
2241
+ np .pi ** 2 ,
2242
+ )
2243
+ np .testing .assert_allclose (
2244
+ grad (x0 , seq )[0 ].eval ({seq : [np .pi , np .nan , np .nan ]}),
2245
+ 2 * np .pi ,
2246
+ )
2247
+
2182
2248
2183
2249
@pytest .mark .skipif (
2184
2250
not config .cxx , reason = "G++ not available, so we need to skip this test."
0 commit comments