@@ -2509,36 +2509,40 @@ def compute_all_gradients(known_grads):
             return rval

         var_mappings = self.get_oinp_iinp_iout_oout_mappings()
-        dC_dinps_t = [None for inp in diff_inputs]
         disconnected_dC_dinps_t = [True for inp in diff_inputs]
+
+        n_mit_mot_outs = info.n_mit_mot_outs
+        # In the case of mit-mot there can be more inner outputs than outer ones
+        n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot
+        idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
+        idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot
+
+        # Create dummy variables for the internal input gradients
+        states = (
+            self.inner_mitmot(self_inputs)
+            + self.inner_mitsot(self_inputs)
+            + self.inner_sitsot(self_inputs)
+        )
+
         dC_dXts = []
         Xts = []
         for idx, Xt in enumerate(diff_outputs):
             # We are looking for x[t-1] for a given x[t]
-            if idx >= info.n_mit_mot_outs:
+            if idx >= n_mit_mot_outs:
                 Xt_placeholder = safe_new(Xt)
                 Xts.append(Xt_placeholder)

             # Different processing based on whether Xt is a nitsot output
             # or not. NOTE: This cannot be done by using
             # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
             # the exact same variable can be used as multiple outputs.
-            idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-            idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
-            if idx < idx_nitsot_start or idx >= idx_nitsot_end:
+            if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
                 # What we do here is loop through dC_douts and collect all
                 # those that are connected to the specific one and do an
                 # upcast on all of their dtypes to get the dtype for this
                 # specific output. Deciding if the gradient with this
                 # specific previous step is defined or not is done somewhere
                 # else.
                 dtypes = []
-                states = (
-                    self.inner_mitmot(self_inputs)
-                    + self.inner_mitsot(self_inputs)
-                    + self.inner_sitsot(self_inputs)
-                )
-
                 for pos, inp in enumerate(states):
                     if inp in graph_inputs([Xt]):
                         # Get the index of the outer output that to which
@@ -2555,35 +2559,39 @@ def compute_all_gradients(known_grads):
                     new_dtype = config.floatX
                 dC_dXt = safe_new(Xt, dtype=new_dtype)
             else:
-                if isinstance(dC_douts[idx].type, DisconnectedType):
+                # nit-sot outputs
+                # If not disconnected, assume the output gradient type is a valid type for the input gradient
+                if isinstance(
+                    dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType
+                ):
                     continue
-                dC_dXt = safe_new(dC_douts[idx][0])
+                dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0])
             dC_dXts.append(dC_dXt)

+        # Handle cases where the very same variable may be used as different outputs
+        # TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
         known_grads = {}
         dc_dxts_idx = 0
         for i in range(len(diff_outputs)):
-            if i < idx_nitsot_start or i >= idx_nitsot_end:
-                if diff_outputs[i] in known_grads:
-                    known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                else:
-                    known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                dc_dxts_idx += 1
+            if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance(
+                dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType
+            ):
+                # Special case where we don't have a dC_dXt for disconnected nitsot outputs
+                continue
+
+            # Just some trouble to avoid a +0
+            if diff_outputs[i] in known_grads:
+                known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
             else:
-                if isinstance(dC_douts[i].type, DisconnectedType):
-                    continue
-                else:
-                    if diff_outputs[i] in known_grads:
-                        known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                    else:
-                        known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                    dc_dxts_idx += 1
+                known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
+            dc_dxts_idx += 1
+
         dC_dinps_t = compute_all_gradients(known_grads)

         # mask inputs that get no gradients
         for dx in range(len(dC_dinps_t)):
-            if not dC_dinps_t[dx]:
-                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
+            if dC_dinps_t[dx] is None:
+                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
             else:
                 disconnected_dC_dinps_t[dx] = False
         for Xt, Xt_placeholder in zip(
@@ -2846,7 +2854,6 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_sit_sot):
             mitmot_inp_taps.append([0, 1])
             mitmot_out_taps.append([1])
-            through_shared = False
             if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
                 outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
             else:
@@ -3007,9 +3014,7 @@ def compute_all_gradients(known_grads):
             name=f"grad_of_{self.name}" if self.name else None,
             allow_gc=self.allow_gc,
         )
-        outputs = local_op(*outer_inputs)
-        if not isinstance(outputs, list | tuple):
-            outputs = [outputs]
+        outputs = local_op(*outer_inputs, return_list=True)
         # Re-order the gradients correctly
         gradients = [DisconnectedType()()]

@@ -3095,7 +3100,6 @@ def compute_all_gradients(known_grads):
                     )
                 )

-        start = len(gradients)
         gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
         begin = end

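A mit-mot state can emit more inner outputs than it has outer outputs (one inner output per output tap), which is why this patch offsets indices into dC_douts by n_extra_mit_mot_outs and recomputes the nit-sot window over inner output indices. Below is a minimal sketch of that arithmetic with made-up counts; none of the numbers come from a real ScanInfo instance.

    # Hypothetical counts: one mit-mot state whose inner function emits 3
    # outputs (one per output tap), plus 1 mit-sot, 1 sit-sot, 2 nit-sot.
    n_mit_mot, n_mit_mot_outs = 1, 3
    n_mit_sot, n_sit_sot, n_nit_sot = 1, 1, 2

    # Inner outputs outnumber outer ones by the extra mit-mot entries
    n_extra_mit_mot_outs = n_mit_mot_outs - n_mit_mot              # 2

    # The nit-sot window, indexed over *inner* outputs
    idx_nitsot_out_start = n_mit_mot_outs + n_mit_sot + n_sit_sot  # 5
    idx_nitsot_out_end = idx_nitsot_out_start + n_nit_sot          # 7

    # Map each nit-sot inner output index to its outer gradient slot
    for idx in range(idx_nitsot_out_start, idx_nitsot_out_end):
        print(f"inner output {idx} -> dC_douts[{idx - n_extra_mit_mot_outs}]")
    # inner output 5 -> dC_douts[3]
    # inner output 6 -> dC_douts[4]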
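The restructured known_grads loop also handles the case called out in the new comment: the very same inner variable can be returned as several outputs, so its output gradients must be summed, with the first occurrence assigned directly to avoid a spurious +0. A small stand-alone sketch of the pattern, using plain floats in place of symbolic gradients:

    diff_outputs = ["x", "y", "x"]  # stand-ins for inner output variables
    dC_dXts = [1.0, 2.0, 3.0]       # stand-ins for the dummy output gradients

    known_grads = {}
    for out, g in zip(diff_outputs, dC_dXts):
        if out in known_grads:
            known_grads[out] += g   # same variable used twice: accumulate
        else:
            known_grads[out] = g    # first occurrence: assign, no +0
    print(known_grads)              # {'x': 4.0, 'y': 2.0}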
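Finally, the outputs = local_op(*outer_inputs, return_list=True) change leans on the return_list keyword of PyTensor's Op.__call__, which wraps a single output in a list so the caller no longer needs an isinstance check. A hedged usage sketch, assuming a PyTensor version where elemwise ops accept this keyword:

    import pytensor.tensor as pt

    x = pt.dscalar("x")
    outs = pt.exp(x, return_list=True)  # one-element list, not a bare variable
    assert isinstance(outs, list) and len(outs) == 1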