@@ -123,10 +123,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
        {
            // generate gradient subgraph for op.
            var op = queue.Dequeue();
-           if (op.name == "rnn/while/Exit")
-           {

-           }
            _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
            {
                if (loop_state != null)
@@ -136,15 +133,14 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                    loop_state.ExitGradWhileContext(op, before: true);

                Tensor[] in_grads = null;
+               Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                var is_partitioned_call = _IsPartitionedCall(op);
                var is_func_call = false;
                var has_out_grads = out_grads.Exists(x => x != null);
                if (has_out_grads && !stop_ops.Contains(op))
                {
                    // A grad_fn must be defined, either as a function or as None
                    // for ops that do not have gradients.
-
-                   Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                    try
                    {
                        grad_fn = ops.get_gradient_function(op);
@@ -167,61 +163,57 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                            throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                        }
                    }
+               }

-               if (loop_state != null)
-                   loop_state.EnterGradWhileContext(op, before: false);
+               if (loop_state != null)
+                   loop_state.EnterGradWhileContext(op, before: false);

-               if ((is_func_call || grad_fn != null) && has_out_grads)
+               if ((is_func_call || grad_fn != null) && has_out_grads)
+               {
+                   // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                   // output, it means that the cost does not depend on output[i],
+                   // therefore dC/doutput[i] is 0.
+                   foreach (var (i, out_grad) in enumerate(out_grads))
                    {
-                       // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                       // output, it means that the cost does not depend on output[i],
-                       // therefore dC/doutput[i] is 0.
-                       foreach (var (i, out_grad) in enumerate(out_grads))
-                       {
-                           if (out_grad == null &&
-                               (grad_fn == null || _IsTrainable(op.outputs[i])))
-                           {
-                               // Only trainable outputs or outputs for a function call that
-                               // will use SymbolicGradient get a zero gradient. Gradient
-                               // functions should ignore the gradient for other outputs.
-                               if (loop_state != null)
-                                   out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                               else
-                                   out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
-                           }
-                       }
-
-                       tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
+                       if (out_grad == null &&
+                           (grad_fn == null || _IsTrainable(op.outputs[i])))
                        {
-                           if (grad_fn != null)
-                           {
-                               in_grads = _MaybeCompile(grad_scope,
-                                   op,
-                                   out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                   null,
-                                   grad_fn);
-                           }
+                           // Only trainable outputs or outputs for a function call that
+                           // will use SymbolicGradient get a zero gradient. Gradient
+                           // functions should ignore the gradient for other outputs.
+                           if (loop_state != null)
+                               out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                            else
-                           {
-                               throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                           }
-                           _VerifyGeneratedGradients(in_grads, op);
-                           if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                           {
-                               ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                               in_grads = control_flow_ops.tuple(in_grads);
-                           }
-                       });
+                               out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                       }
                    }
-                   else
+
+                   tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                    {
-                       // If no grad_fn is defined or none of out_grads is available,
-                       // just propagate a list of None backwards.
-                       in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                   }
+                       if (grad_fn != null)
+                       {
+                           in_grads = _MaybeCompile(grad_scope,
+                               op,
+                               out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                               null,
+                               grad_fn);
+                       }
+                       else
+                       {
+                           throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                       }
+                       _VerifyGeneratedGradients(in_grads, op);
+                       if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                       {
+                           ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                           in_grads = control_flow_ops.tuple(in_grads);
+                       }
+                   });
                }
                else
                {
+                   // If no grad_fn is defined or none of out_grads is available,
+                   // just propagate a list of None backwards.
                    in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                }

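The restructured block above first replaces every missing upstream gradient with zeros (dC/doutput[i] = 0 whenever the cost does not depend on output[i]) and only then invokes the op's gradient function. Below is a minimal, self-contained sketch of that zero-fill rule, using plain float[] arrays as stand-ins for Tensor; the helper names are hypothetical and are not TensorFlow.NET API.

// Hypothetical illustration only: float[] stands in for Tensor.
using System;
using System.Linq;

static class ZeroFillSketch
{
    // Replace a missing (null) upstream gradient with zeros of the right length,
    // mirroring ZerosLike / ZerosLikeOutsideLoop in the diff above.
    static float[][] FillMissingWithZeros(float[][] outGrads, int[] outputSizes)
    {
        return outGrads
            .Select((g, i) => g ?? new float[outputSizes[i]])
            .ToArray();
    }

    static void Main()
    {
        // Output 0 received a gradient; output 1 did not (the cost does not depend on it).
        var outGrads = new float[][] { new[] { 0.5f, -1.0f }, null };
        var filled = FillMissingWithZeros(outGrads, new[] { 2, 3 });
        Console.WriteLine(string.Join(", ", filled[1]));  // 0, 0, 0
    }
}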
@@ -370,7 +362,16 @@ private static void _SetGrad(Dictionary<string, List<List<Tensor>>> grads, Tenso
            grads[op.name] = op_grads;
        }
        var t_grads = op_grads[t.value_index];
-       t_grads.Add(grad);
+       if (t_grads.Count == 0)
+           t_grads.Add(grad);
+       else
+           op_grads[t.value_index][0] = grad;
+
+       /*if (control_flow_util.IsLoopSwitch(op) &&
+           t_grads[0] == null)
+           op_grads[t.value_index] = new List<Tensor> { grad };
+       else
+           t_grads.Add(grad);*/
    }

    private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
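For reference, the grads dictionary that _SetGrad fills maps an op name to one bucket per output index, each bucket holding the gradients recorded for that output. The following is a hedged sketch of the overwrite-instead-of-append behaviour introduced above, with string used as a stand-in for Tensor and all names hypothetical.

// Hypothetical sketch of the bookkeeping shape behind _SetGrad / _GetGrads:
// op name -> one bucket per output index -> pending gradients for that output.
using System;
using System.Collections.Generic;

static class SetGradSketch
{
    static void SetGrad(Dictionary<string, List<List<string>>> grads,
                        string opName, int valueIndex, int numOutputs, string grad)
    {
        if (!grads.TryGetValue(opName, out var opGrads))
        {
            // One (initially empty) bucket per op output.
            opGrads = new List<List<string>>();
            for (int i = 0; i < numOutputs; i++)
                opGrads.Add(new List<string>());
            grads[opName] = opGrads;
        }

        var tGrads = opGrads[valueIndex];
        // Behaviour of the change above: keep a single slot per output,
        // overwriting any gradient that was already recorded.
        if (tGrads.Count == 0)
            tGrads.Add(grad);
        else
            tGrads[0] = grad;
    }

    static void Main()
    {
        var grads = new Dictionary<string, List<List<string>>>();
        SetGrad(grads, "MatMul_1", 0, 1, "g1");
        SetGrad(grads, "MatMul_1", 0, 1, "g2");        // overwrites, does not append
        Console.WriteLine(grads["MatMul_1"][0].Count);  // 1
    }
}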
@@ -379,15 +380,19 @@ private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
                yield return op.inputs[i];
    }

-   private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+   private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
+       ControlFlowState loop_state, int aggregation_method = 0)
    {
        var out_grads = _GetGrads(grads, op);

        foreach (var (i, out_grad) in enumerate(out_grads))
        {
            if (loop_state != null)
            {
-
+               if (out_grads.Count > 1 &&
+                   out_grads[1].Count > 0 &&
+                   control_flow_util.IsLoopSwitch(op))
+                   continue;
            }

            // Aggregate multiple gradients, and convert [] to None.
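_AggregatedGrads then reduces each per-output bucket to a single value and, per the comment above, converts an empty bucket to None (null). Below is a rough sketch of that reduction under the simplifying assumption that a gradient is a single float and aggregation is a plain sum; the real code adds tensors element-wise, and all names here are hypothetical.

// Hypothetical sketch of per-output aggregation: sum the pending gradients
// for each output slot, and map an empty bucket to null ("[] to None").
using System;
using System.Collections.Generic;
using System.Linq;

static class AggregateSketch
{
    static float?[] Aggregate(List<List<float>> outGrads)
    {
        return outGrads
            .Select(bucket => bucket.Count == 0
                ? (float?)null                // no gradient flowed to this output
                : bucket.Sum())               // element-wise add in the real code
            .ToArray();
    }

    static void Main()
    {
        var outGrads = new List<List<float>>
        {
            new List<float> { 1.0f, 2.5f },   // two gradients reach output 0
            new List<float>()                 // nothing reaches output 1
        };
        var aggregated = Aggregate(outGrads);
        Console.WriteLine($"{aggregated[0]}, {aggregated[1] == null}");  // 3.5, True
    }
}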