
Commit 4aea87c (parent: 49cf9d2)

Fix bug when taking the L_op of a Scan with mit-mot and disconnected output gradients
File tree: 2 files changed (+89 −34 lines)
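For context, the bug surfaced when a pushforward (R_op) of a Scan gradient was computed via two applications of pullback (L_op), as the regression test below notes. A minimal sketch of that pattern, using a stand-in graph rather than the Scan from this commit (`u` and `v` are illustrative dummy variables, not names from the diff):

import pytensor.tensor as pt
from pytensor.gradient import Lop

x = pt.vector("x")
y = pt.tanh(x)  # stand-in for a Scan output

# J^T @ u is linear in u, so pulling it back through u recovers J @ v.
u = y.type("u")        # dummy cotangent, same type as y
v = x.type("v")        # tangent vector for x
Jt_u = Lop(y, x, u)    # first pullback:  J^T @ u
J_v = Lop(Jt_u, u, v)  # second pullback: J @ v (the pushforward)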

pytensor/scan/op.py

Lines changed: 38 additions & 34 deletions
@@ -2509,36 +2509,40 @@ def compute_all_gradients(known_grads):
             return rval
 
         var_mappings = self.get_oinp_iinp_iout_oout_mappings()
-        dC_dinps_t = [None for inp in diff_inputs]
         disconnected_dC_dinps_t = [True for inp in diff_inputs]
+
+        n_mit_mot_outs = info.n_mit_mot_outs
+        # In the case of mit-mot there can be more inner outputs than outer ones
+        n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot
+        idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
+        idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot
+
+        # Create dummy variables for the internal input gradients
+        states = (
+            self.inner_mitmot(self_inputs)
+            + self.inner_mitsot(self_inputs)
+            + self.inner_sitsot(self_inputs)
+        )
         dC_dXts = []
         Xts = []
         for idx, Xt in enumerate(diff_outputs):
             # We are looking for x[t-1] for a given x[t]
-            if idx >= info.n_mit_mot_outs:
+            if idx >= n_mit_mot_outs:
                 Xt_placeholder = safe_new(Xt)
                 Xts.append(Xt_placeholder)
 
             # Different processing based on whether Xt is a nitsot output
             # or not. NOTE : This cannot be done by using
             # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
             # the exact same variable can be used as multiple outputs.
-            idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-            idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
-            if idx < idx_nitsot_start or idx >= idx_nitsot_end:
+            if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
                 # What we do here is loop through dC_douts and collect all
                 # those that are connected to the specific one and do an
                 # upcast on all of their dtypes to get the dtype for this
                 # specific output. Deciding if the gradient with this
                 # specific previous step is defined or not is done somewhere
                 # else.
                 dtypes = []
-                states = (
-                    self.inner_mitmot(self_inputs)
-                    + self.inner_mitsot(self_inputs)
-                    + self.inner_sitsot(self_inputs)
-                )
                 for pos, inp in enumerate(states):
                     if inp in graph_inputs([Xt]):
                         # Get the index of the outer output that to which
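The renamed indices do the heavy lifting of this hunk: the nit-sot block is now located in inner-output coordinates, and `n_extra_mit_mot_outs` maps an inner output index back to its outer output. A self-contained sketch of the arithmetic with made-up counts (the names mirror the diff; the numbers are purely illustrative):

# Illustrative counts only: one mit-mot state whose inner function emits
# two output taps yields two inner outputs but only one outer output.
n_mit_mot = 1                                    # outer mit-mot outputs
n_mit_mot_outs = 2                               # inner mit-mot outputs
n_mit_sot, n_sit_sot, n_nit_sot = 0, 1, 1

n_extra_mit_mot_outs = n_mit_mot_outs - n_mit_mot              # 1
# The nit-sot block starts after all *inner* recurrent outputs,
# not after the outer ones.
idx_nitsot_out_start = n_mit_mot_outs + n_mit_sot + n_sit_sot  # 3
idx_nitsot_out_end = idx_nitsot_out_start + n_nit_sot          # 4

# Inner output index -> outer output (dC_douts) index:
for i in range(idx_nitsot_out_start, idx_nitsot_out_end):
    assert i - n_extra_mit_mot_outs == 2  # inner index 3 -> outer index 2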
@@ -2555,35 +2559,39 @@ def compute_all_gradients(known_grads):
                     new_dtype = config.floatX
                 dC_dXt = safe_new(Xt, dtype=new_dtype)
             else:
-                if isinstance(dC_douts[idx].type, DisconnectedType):
+                # nit-sot outputs
+                # If not disconnected, assume the output gradient type is a valid type for the input gradient
+                if isinstance(
+                    dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType
+                ):
                     continue
-                dC_dXt = safe_new(dC_douts[idx][0])
+                dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0])
             dC_dXts.append(dC_dXt)
 
+        # Handle cases where the very same variable may be used as different outputs
+        # TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
         known_grads = {}
         dc_dxts_idx = 0
         for i in range(len(diff_outputs)):
-            if i < idx_nitsot_start or i >= idx_nitsot_end:
-                if diff_outputs[i] in known_grads:
-                    known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                else:
-                    known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                dc_dxts_idx += 1
-            else:
-                if isinstance(dC_douts[i].type, DisconnectedType):
-                    continue
-                else:
-                    if diff_outputs[i] in known_grads:
-                        known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                    else:
-                        known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                    dc_dxts_idx += 1
+            if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance(
+                dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType
+            ):
+                # Special case where we don't have a dC_dXt for disconnected nitsot outputs
+                continue
+
+            # Just some trouble to avoid a +0
+            if diff_outputs[i] in known_grads:
+                known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
+            else:
+                known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
+            dc_dxts_idx += 1
 
         dC_dinps_t = compute_all_gradients(known_grads)
 
         # mask inputs that get no gradients
         for dx in range(len(dC_dinps_t)):
-            if not dC_dinps_t[dx]:
-                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
+            if dC_dinps_t[dx] is None:
+                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
             else:
                 disconnected_dC_dinps_t[dx] = False
         for Xt, Xt_placeholder in zip(
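The `known_grads` loop above sums gradient terms when one inner variable backs several outputs (the aliasing case the comment warns about). A toy model of that accumulation pattern, using a plain dict and illustrative names:

# Two outputs alias the same variable "xt", so their output gradients
# must be summed before back-propagation; "yt" is backed by one output.
diff_outputs = ["xt", "xt", "yt"]
dC_dXts = [1.0, 2.0, 5.0]  # one gradient term per connected output

known_grads = {}
dc_dxts_idx = 0
for out in diff_outputs:
    if out in known_grads:
        known_grads[out] += dC_dXts[dc_dxts_idx]
    else:
        known_grads[out] = dC_dXts[dc_dxts_idx]
    dc_dxts_idx += 1

assert known_grads == {"xt": 3.0, "yt": 5.0}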
@@ -2846,7 +2854,6 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_sit_sot):
             mitmot_inp_taps.append([0, 1])
             mitmot_out_taps.append([1])
-            through_shared = False
             if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
                 outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
             else:
@@ -3007,9 +3014,7 @@ def compute_all_gradients(known_grads):
             name=f"grad_of_{self.name}" if self.name else None,
             allow_gc=self.allow_gc,
         )
-        outputs = local_op(*outer_inputs)
-        if not isinstance(outputs, list | tuple):
-            outputs = [outputs]
+        outputs = local_op(*outer_inputs, return_list=True)
         # Re-order the gradients correctly
         gradients = [DisconnectedType()()]
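`return_list=True` replaces the manual list normalization: an `Op` call then returns a list regardless of how many outputs the op has. A hedged usage sketch (accessing the op via `.owner.op` is just for illustration, not how the commit uses it):

import pytensor.tensor as pt

x = pt.vector("x")
exp_op = pt.exp(x).owner.op  # Elemwise exp, a single-output Op

out = exp_op(x)                     # -> a single TensorVariable
outs = exp_op(x, return_list=True)  # -> [TensorVariable]
assert isinstance(outs, list) and len(outs) == 1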

@@ -3095,7 +3100,6 @@ def compute_all_gradients(known_grads):
                 )
             )
 
-        start = len(gradients)
         gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
         begin = end

tests/scan/test_basic.py

Lines changed: 51 additions & 0 deletions
@@ -2128,6 +2128,57 @@ def test_R_op_mitmot(self):
         # TODO: We should test something about the Rop!
         Rop(d_cost_wrt_pars, pars, p)
 
+    def test_second_derivative_disconnected_cost_with_mit_mot(self):
+        # This test is a regression test for a bug that was revealed
+        # when we computed the pushforward of a Scan gradient via two
+        # applications of pullback
+        seq = pt.vector("seq", shape=(2,))
+        z = pt.scalar("z")
+        x0 = pt.vector("x0", shape=(2,))
+
+        # When every entry of seq is 1 and z is 2, xs[-1] is just a sneaky
+        # x ** 4 (after two steps), so grad should be 4 * x ** 3,
+        # and grad of grad should be 12 * x ** 2
+        def step(s, xtm2, xtm1, z):
+            return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2)
+
+        xs, _ = scan(
+            step,
+            sequences=[seq],
+            outputs_info=[{"initial": x0, "taps": (-2, -1)}],
+            non_sequences=[z],
+            n_steps=2,
+        )
+        last_x = xs[-1]
+
+        g_wrt_x0, g_wrt_z, g_wrt_seq = pt.grad(last_x, [x0, z, seq])
+        g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + g_wrt_seq.sum() * 0
+        assert g.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 4
+        gg = pt.grad(g, wrt=x0).sum()
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
+        assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96
+
+        # Leave out z
+        g_wrt_x0, g_wrt_seq = pt.grad(last_x, [x0, seq])
+        g = g_wrt_x0.sum() + g_wrt_seq.sum() * 0
+        gg = pt.grad(g, wrt=x0).sum()
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
+        assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96
+
+        # Leave out seq
+        g_wrt_x0, g_wrt_z = pt.grad(last_x, [x0, z])
+        g = g_wrt_x0.sum() + g_wrt_z.sum() * 0
+        gg = pt.grad(g, wrt=x0).sum()
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2
+
+        # Leave out z and seq
+        g_wrt_x0 = pt.grad(last_x, x0)
+        g = g_wrt_x0.sum()
+        gg = pt.grad(g, wrt=x0).sum()
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
+        assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2
 
 
 @pytest.mark.skipif(
     not config.cxx, reason="G++ not available, so we need to skip this test."