
Commit ad250d0

committed: implement _SwitchGrad when merge_grad is not null.
1 parent fcd2cd6 commit ad250d0


3 files changed (+67 -68 lines)


src/TensorFlowNET.Core/Gradients/control_flow_grad.cs

Lines changed: 8 additions & 11 deletions
@@ -48,7 +48,12 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
         {
             var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
             if (merge_grad != null)
-                throw new NotImplementedException("_SwitchGrad merge_grad != null");
+            {
+                if (grads[1] != null)
+                    control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1],
+                        enforce_shape_invariant: false);
+                return new Tensor[] { null, null };
+            }
             else if (grads[0] != null)
             {
                 merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
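
The hunk above replaces the NotImplementedException with the two-visit protocol for Switch gradients: the first visit builds and caches a Merge gradient in grad_state.switch_map, and a later visit reuses the cached merge_grad and wires the newly arrived gradient in as a back edge via control_flow_ops._AddNextAndBackEdge. Below is a minimal, self-contained sketch of that caching behavior; every type and helper in it (FakeTensor, FakeGradState, AddNextAndBackEdge, Merge) is a stand-in invented for illustration, not TensorFlow.NET API.

// All names below are invented for this sketch; they only model the control flow of
// _SwitchGrad after this commit, not the real TensorFlow.NET types.
using System;
using System.Collections.Generic;

class FakeTensor
{
    public string Name;
    public FakeTensor(string name) { Name = name; }
}

class FakeGradState
{
    // Plays the role of grad_state.switch_map: Switch op name -> cached merge gradient.
    public Dictionary<string, FakeTensor> SwitchMap = new Dictionary<string, FakeTensor>();
}

static class SwitchGradSketch
{
    // Stand-in for control_flow_ops._AddNextAndBackEdge.
    static void AddNextAndBackEdge(FakeTensor mergeGrad, FakeTensor grad)
    {
        Console.WriteLine("back edge: " + grad.Name + " -> " + mergeGrad.Name);
    }

    // Stand-in for the two-input merge() used on the first visit.
    static FakeTensor Merge(FakeTensor a, FakeTensor b)
    {
        return new FakeTensor("b_switch(" + a.Name + "," + b.Name + ")");
    }

    public static FakeTensor[] SwitchGrad(string opName, FakeTensor[] grads, FakeGradState state)
    {
        FakeTensor mergeGrad;
        if (state.SwitchMap.TryGetValue(opName, out mergeGrad))
        {
            // Later visits: reuse the cached merge gradient and connect the newly
            // arrived gradient as the back edge (the new branch in this hunk).
            if (grads[1] != null)
                AddNextAndBackEdge(mergeGrad, grads[1]);
            return new FakeTensor[] { null, null };
        }
        else if (grads[0] != null)
        {
            // First visit: build the merge gradient and cache it for the next visit.
            mergeGrad = Merge(grads[0], grads[0]);
            state.SwitchMap[opName] = mergeGrad;
            return new FakeTensor[] { mergeGrad, null };
        }
        return new FakeTensor[] { null, null };
    }

    static void Main()
    {
        var state = new FakeGradState();
        var first = SwitchGrad("while/Switch",
            new FakeTensor[] { new FakeTensor("dL/dExit"), null }, state);
        Console.WriteLine("first visit -> " + first[0].Name);
        SwitchGrad("while/Switch",
            new FakeTensor[] { null, new FakeTensor("dL/dNextIteration") }, state);
    }
}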
@@ -233,17 +238,9 @@ public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
                 return grads;
             if (op.get_attr<bool>("is_constant"))
             {
-                throw new NotImplementedException("_EnterGrad is_constant");
-                // Add a gradient accumulator for each loop invariant.
-                // if isinstance(grad, ops.Tensor) :
-                //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
-                // elif isinstance(grad, ops.IndexedSlices) :
-                //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-                // else:
-                //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-                //     raise TypeError("Type %s not supported" % type(grad))
+                // Add a gradient accumulator for each loop invariant.
+                result = grad_ctxt.AddBackpropAccumulator(op, grad);
             }
-
             else
             {
                 result = control_flow_ops.exit(grad);
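
The second hunk turns on gradient accumulation for loop-invariant Enter ops instead of throwing. Conceptually, the gradient of a value that is constant across loop iterations arrives once per iteration during backprop and has to be summed; the toy below models only that idea (Accumulator and the sample values are invented for the sketch, and the real grad_ctxt.AddBackpropAccumulator builds graph ops rather than summing eagerly).

// Invented toy model: sums per-iteration gradients for a loop-invariant input,
// the role grad_ctxt.AddBackpropAccumulator plays in the real graph.
using System;
using System.Collections.Generic;

class BackpropAccumulatorSketch
{
    class Accumulator
    {
        private readonly List<double> parts = new List<double>();
        public void Add(double grad) { parts.Add(grad); }
        public double Total()
        {
            double sum = 0;
            foreach (var p in parts) sum += p;
            return sum;
        }
    }

    static void Main()
    {
        var acc = new Accumulator();
        // Backward pass over a 3-iteration loop: the invariant's gradient arrives three times.
        foreach (var perIterationGrad in new[] { 0.5, 0.25, 0.25 })
            acc.Add(perIterationGrad);
        Console.WriteLine(acc.Total());   // 1.0: total gradient w.r.t. the loop invariant
    }
}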

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 58 additions & 53 deletions
@@ -123,10 +123,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if(op.name == "rnn/while/Exit")
-                    {
 
-                    }
                     _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                     {
                         if (loop_state != null)
@@ -136,15 +133,14 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                            loop_state.ExitGradWhileContext(op, before: true);
 
                        Tensor[] in_grads = null;
+                       Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                        var is_partitioned_call = _IsPartitionedCall(op);
                        var is_func_call = false;
                        var has_out_grads = out_grads.Exists(x => x != null);
                        if (has_out_grads && !stop_ops.Contains(op))
                        {
                            // A grad_fn must be defined, either as a function or as None
                            // for ops that do not have gradients.
-
-                           Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                            try
                            {
                                grad_fn = ops.get_gradient_function(op);
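
This hunk hoists the grad_fn declaration out of the stop-op check so it sits next to in_grads; the "(is_func_call || grad_fn != null) && has_out_grads" test, which the next hunk moves outside that block, can then still read it after the block closes. A minimal scoping illustration follows, with all values invented for the sketch.

// Invented example: a local declared inside a block is not visible after it,
// which is why grad_fn is now declared alongside in_grads.
using System;

class GradFnScopeSketch
{
    static void Main()
    {
        Func<int, int> grad_fn = null;          // hoisted, like the added line in this hunk
        bool has_out_grads = true, is_stop_op = false;

        if (has_out_grads && !is_stop_op)
        {
            grad_fn = x => 2 * x;               // assigned where the lookup happens
        }

        // The gradient-application branch can still see grad_fn after the block closes.
        if (grad_fn != null && has_out_grads)
            Console.WriteLine(grad_fn(3));      // prints 6
    }
}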
@@ -167,61 +163,57 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                                throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                            }
                        }
+                       }
 
-                           if (loop_state != null)
-                               loop_state.EnterGradWhileContext(op, before: false);
+                       if (loop_state != null)
+                           loop_state.EnterGradWhileContext(op, before: false);
 
-                           if ((is_func_call || grad_fn != null) && has_out_grads)
+                       if ((is_func_call || grad_fn != null) && has_out_grads)
+                       {
+                           // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                           // output, it means that the cost does not depend on output[i],
+                           // therefore dC/doutput[i] is 0.
+                           foreach (var (i, out_grad) in enumerate(out_grads))
                            {
-                               // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                               // output, it means that the cost does not depend on output[i],
-                               // therefore dC/doutput[i] is 0.
-                               foreach (var (i, out_grad) in enumerate(out_grads))
-                               {
-                                   if (out_grad == null &&
-                                       (grad_fn == null || _IsTrainable(op.outputs[i])))
-                                   {
-                                       // Only trainable outputs or outputs for a function call that
-                                       // will use SymbolicGradient get a zero gradient. Gradient
-                                       // functions should ignore the gradient for other outputs.
-                                       if (loop_state != null)
-                                           out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                                       else
-                                           out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
-                                   }
-                               }
-
-                               tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
+                               if (out_grad == null &&
+                                   (grad_fn == null || _IsTrainable(op.outputs[i])))
                                {
-                                   if (grad_fn != null)
-                                   {
-                                       in_grads = _MaybeCompile(grad_scope,
-                                           op,
-                                           out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                           null,
-                                           grad_fn);
-                                   }
+                                   // Only trainable outputs or outputs for a function call that
+                                   // will use SymbolicGradient get a zero gradient. Gradient
+                                   // functions should ignore the gradient for other outputs.
+                                   if (loop_state != null)
+                                       out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                                    else
-                                   {
-                                       throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                                   }
-                                   _VerifyGeneratedGradients(in_grads, op);
-                                   if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                                   {
-                                       ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                       in_grads = control_flow_ops.tuple(in_grads);
-                                   }
-                               });
+                                       out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                               }
                            }
-                           else
+
+                           tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                            {
-                               // If no grad_fn is defined or none of out_grads is available,
-                               // just propagate a list of None backwards.
-                               in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                           }
+                               if (grad_fn != null)
+                               {
+                                   in_grads = _MaybeCompile(grad_scope,
+                                       op,
+                                       out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                       null,
+                                       grad_fn);
+                               }
+                               else
+                               {
+                                   throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                               }
+                               _VerifyGeneratedGradients(in_grads, op);
+                               if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                               {
+                                   ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                   in_grads = control_flow_ops.tuple(in_grads);
+                               }
+                           });
                        }
                        else
                        {
+                           // If no grad_fn is defined or none of out_grads is available,
+                           // just propagate a list of None backwards.
                            in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                        }
 
@@ -370,7 +362,16 @@ private static void _SetGrad(Dictionary<string, List<List<Tensor>>> grads, Tenso
                grads[op.name] = op_grads;
            }
            var t_grads = op_grads[t.value_index];
-           t_grads.Add(grad);
+           if (t_grads.Count == 0)
+               t_grads.Add(grad);
+           else
+               op_grads[t.value_index][0] = grad;
+
+           /*if (control_flow_util.IsLoopSwitch(op) &&
+               t_grads[0] == null)
+               op_grads[t.value_index] = new List<Tensor> { grad };
+           else
+               t_grads.Add(grad);*/
        }
 
        private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
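
The _SetGrad hunk above changes accumulation for a revisited output: the first gradient is appended as before, but a later gradient for the same output now overwrites slot 0 instead of growing the list (the commented-out block preserves the loop-switch special case that was considered instead). The standalone toy below mirrors only that append-vs-overwrite behavior; the string gradients and dictionary shape are invented for the sketch.

// Invented toy: grads[opName][outputIndex] holds the gradients recorded for that output.
using System;
using System.Collections.Generic;

class SetGradSketch
{
    static readonly Dictionary<string, List<List<string>>> grads =
        new Dictionary<string, List<List<string>>>();

    static void SetGrad(string opName, int valueIndex, string grad)
    {
        List<List<string>> op_grads;
        if (!grads.TryGetValue(opName, out op_grads))
        {
            op_grads = new List<List<string>>();
            for (int i = 0; i <= valueIndex; i++)
                op_grads.Add(new List<string>());
            grads[opName] = op_grads;
        }

        var t_grads = op_grads[valueIndex];
        if (t_grads.Count == 0)
            t_grads.Add(grad);               // first visit: append
        else
            op_grads[valueIndex][0] = grad;  // revisit: replace rather than accumulate
    }

    static void Main()
    {
        SetGrad("while/Switch", 0, "g1");
        SetGrad("while/Switch", 0, "g2");    // overwrites; the list stays length 1
        Console.WriteLine(string.Join(", ", grads["while/Switch"][0]));   // prints "g2"
    }
}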
@@ -379,15 +380,19 @@ private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
                yield return op.inputs[i];
        }
 
-       private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+       private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
+           ControlFlowState loop_state, int aggregation_method = 0)
        {
            var out_grads = _GetGrads(grads, op);
 
            foreach (var (i, out_grad) in enumerate(out_grads))
            {
                if (loop_state != null)
                {
-
+                   if (out_grads.Count > 1 &&
+                       out_grads[1].Count > 0 &&
+                       control_flow_util.IsLoopSwitch(op))
+                       continue;
                }
 
                // Aggregate multiple gradients, and convert [] to None.
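
_AggregatedGrads now takes the loop state as ControlFlowState and, while inside a loop gradient state, skips aggregation for a Switch op whose second output already carries gradients, leaving those entries to the control-flow machinery. A schematic stand-in with invented names follows (IsLoopSwitch here is a trivial type check, unlike the real control_flow_util helper).

// Invented schematic of the new guard; only the skip logic is modeled.
using System;
using System.Collections.Generic;

class AggregatedGradsSketch
{
    static bool IsLoopSwitch(string opType) { return opType == "Switch"; }   // simplified stand-in

    static void Aggregate(List<List<string>> out_grads, string opType, bool inLoopState)
    {
        for (int i = 0; i < out_grads.Count; i++)
        {
            if (inLoopState &&
                out_grads.Count > 1 &&
                out_grads[1].Count > 0 &&
                IsLoopSwitch(opType))
                continue;                     // leave loop-switch gradients untouched

            // Normal path: sum multiple gradients and convert empty lists to null.
            Console.WriteLine("aggregating output " + i + " (" + out_grads[i].Count + " grads)");
        }
    }

    static void Main()
    {
        var out_grads = new List<List<string>>
        {
            new List<string> { "g0" },
            new List<string> { "g1" }
        };
        Aggregate(out_grads, "Switch", true);   // prints nothing: both outputs skipped
        Aggregate(out_grads, "Add", false);     // aggregates both outputs
    }
}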

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 4 deletions
@@ -182,10 +182,7 @@ public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[
            // This will be set by self.inputs.
            if (op_def == null)
                op_def = g.GetOpDef(node_def.Op);
-           if(node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")
-           {
-
-           }
+
            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
            _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
            _is_stateful = op_def.IsStateful;
