Skip to content

Commit 1c88f72

Browse files
committed
Handle Scan gradients of non shaped disconnected inputs
1 parent a179f9f commit 1c88f72

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

pytensor/scan/op.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from pytensor.graph.features import NoOutputFromInplace
7373
from pytensor.graph.op import HasInnerGraph, Op
7474
from pytensor.graph.replace import clone_replace
75+
from pytensor.graph.type import HasShape
7576
from pytensor.graph.utils import InconsistencyError, MissingInputError
7677
from pytensor.link.c.basic import CLinker
7778
from pytensor.printing import op_debug_information
@@ -2591,7 +2592,11 @@ def compute_all_gradients(known_grads):
25912592
# mask inputs that get no gradients
25922593
for dx in range(len(dC_dinps_t)):
25932594
if dC_dinps_t[dx] is None:
2594-
dC_dinps_t[dx] = dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
2595+
dC_dinps_t[dx] = dC_dinps_t[dx] = (
2596+
pt.zeros_like(diff_inputs[dx])
2597+
if isinstance(diff_inputs[dx].type, HasShape)
2598+
else pt.zeros(())
2599+
)
25952600
else:
25962601
disconnected_dC_dinps_t[dx] = False
25972602
for Xt, Xt_placeholder in zip(
@@ -2965,7 +2970,8 @@ def compute_all_gradients(known_grads):
29652970
else:
29662971
outer_inp_sitsot.append(
29672972
pt.zeros(
2968-
[grad_steps + 1] + [x.shape[i] for i in range(x.ndim)],
2973+
[grad_steps + 1]
2974+
+ (list(x.shape) if isinstance(x.type, HasShape) else []),
29692975
dtype=y.dtype,
29702976
)
29712977
)

tests/scan/test_basic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,6 +2179,72 @@ def step(s, xtm2, xtm1, z):
21792179
assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
21802180
assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2
21812181

2182+
@pytest.mark.parametrize("case", ("inside-explicit", "inside-implicit", "outside"))
2183+
def test_non_shaped_input_disconnected_gradient(self, case):
2184+
"""Test that Scan gradient works when non shaped variables are disconnected from the gradient.
2185+
2186+
Regression test for https://github.com/pymc-devs/pytensor/issues/6
2187+
"""
2188+
2189+
# In all cases rng is disconnected from the output gradient
2190+
# Note that when it is an input to the scan (explicit or not) it is still not updated by the scan,
2191+
# so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan.
2192+
rng = shared(np.random.default_rng())
2193+
2194+
data = pt.zeros(16)
2195+
2196+
nonlocal_random_index = pt.random.integers(16, rng=rng)
2197+
nonlocal_random_datum = data[nonlocal_random_index]
2198+
2199+
if case == "outside":
2200+
2201+
def step(s, random_datum):
2202+
return (random_datum + s) ** 2
2203+
2204+
strict = True
2205+
non_sequences = [nonlocal_random_datum]
2206+
2207+
elif case == "inside-implicit":
2208+
2209+
def step(s):
2210+
return (nonlocal_random_datum + s) ** 2
2211+
2212+
strict = False
2213+
non_sequences = [] # Scan will introduce the non_sequences for us
2214+
2215+
elif case == "inside-explicit":
2216+
2217+
def step(s, data, rng):
2218+
random_index = pt.random.integers(
2219+
16, rng=rng
2220+
) # Not updated by the scan
2221+
random_datum = data[random_index]
2222+
return (random_datum + s) ** 2
2223+
2224+
strict = (True,)
2225+
non_sequences = [data, rng]
2226+
2227+
else:
2228+
raise ValueError(f"Invalid case: {case}")
2229+
2230+
seq = vector("seq")
2231+
xs, _ = scan(
2232+
step,
2233+
sequences=[seq],
2234+
non_sequences=non_sequences,
2235+
strict=strict,
2236+
)
2237+
x0 = xs[0]
2238+
2239+
np.testing.assert_allclose(
2240+
x0.eval({seq: [np.pi, np.nan, np.nan]}),
2241+
np.pi**2,
2242+
)
2243+
np.testing.assert_allclose(
2244+
grad(x0, seq)[0].eval({seq: [np.pi, np.nan, np.nan]}),
2245+
2 * np.pi,
2246+
)
2247+
21822248

21832249
@pytest.mark.skipif(
21842250
not config.cxx, reason="G++ not available, so we need to skip this test."

0 commit comments

Comments
 (0)