
Commit 99bd08b

fix object reference issue for _AggregatedGrads #303
1 parent 4a36747

4 files changed: 46 additions & 14 deletions
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <AssemblyName>TensorFlow.Net.Hub</AssemblyName>
     <RootNamespace>Tensorflow.Hub</RootNamespace>
@@ -8,7 +8,7 @@
     <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
   </ItemGroup>
   <ItemGroup>
-    <PackageReference Include="NumSharp" Version="0.10.4" />
+    <PackageReference Include="NumSharp" Version="0.10.5" />
     <PackageReference Include="sharpcompress" Version="0.23.0" />
   </ItemGroup>
 </Project>

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 10 additions & 11 deletions
@@ -137,7 +137,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                             if (loop_state != null)
                                 ;
                             else
-                                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
+                                out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                         }
                     }

@@ -146,7 +146,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                     string name1 = scope1;
                     if (grad_fn != null)
                     {
-                        in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
+                        in_grads = _MaybeCompile(grad_scope, op, out_grads[0].ToArray(), null, grad_fn);
                         _VerifyGeneratedGradients(in_grads, op);
                     }

@@ -310,10 +310,9 @@ private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
                 yield return op.inputs[i];
         }

-        private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
         {
             var out_grads = _GetGrads(grads, op);
-            var return_grads = new Tensor[out_grads.Count];

             foreach (var (i, out_grad) in enumerate(out_grads))
             {
@@ -334,21 +333,21 @@ private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>>
                         throw new ValueError("_AggregatedGrads out_grad.Length == 0");
                     }

-                        return_grads[i] = out_grad[0];
+                        out_grads[i] = out_grad;
                     }
                     else
                     {
                         used = "add_n";
-                        return_grads[i] = _MultiDeviceAddN(out_grad.ToArray(), gradient_uid);
+                        out_grads[i] = new List<Tensor> { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) };
                     }
                 }
                 else
                 {
-                    return_grads[i] = null;
+                    out_grads[i] = null;
                 }
             }

-            return return_grads;
+            return out_grads;
         }

         /// <summary>
@@ -362,18 +361,18 @@ private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid
             // Basic function structure comes from control_flow_ops.group().
             // Sort tensors according to their devices.
             var tensors_on_device = new Dictionary<string, List<Tensor>>();
-
+
             foreach (var tensor in tensor_list)
             {
                 if (!tensors_on_device.ContainsKey(tensor.Device))
                     tensors_on_device[tensor.Device] = new List<Tensor>();

                 tensors_on_device[tensor.Device].Add(tensor);
             }
-
+
             // For each device, add the tensors on that device first.
             var summands = new List<Tensor>();
-            foreach(var dev in tensors_on_device.Keys)
+            foreach (var dev in tensors_on_device.Keys)
             {
                 var tensors = tensors_on_device[dev];
                 ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true);
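
The core of the change is that _AggregatedGrads now mutates and returns the out_grads list it gets from _GetGrads instead of copying results into a fresh Tensor[]; the commit title suggests the old copy broke the object reference that the surrounding gradient bookkeeping relies on, and the out_grads[0].ToArray() call in _GradientsHelper then unwraps the first output's aggregated list back into the Tensor[] that _MaybeCompile expects. Below is a minimal, standalone sketch of that aliasing behaviour in plain C# (not the TensorFlow.NET types; the names GradsAliasingSketch and the "op1" key are invented for illustration): mutating the retrieved inner list in place is visible through the dictionary, while writing into a separate array would not be.

using System;
using System.Collections.Generic;
using System.Linq;

class GradsAliasingSketch
{
    static void Main()
    {
        // Stand-in for the grads dictionary: op name -> per-output lists of partial gradients.
        var grads = new Dictionary<string, List<List<int>>>
        {
            ["op1"] = new List<List<int>> { new List<int> { 1, 2 }, new List<int> { 3 } }
        };

        // A _GetGrads-style lookup returns the same List object the dictionary holds.
        var outGrads = grads["op1"];

        // Aggregate in place, as the patched _AggregatedGrads does: each slot is replaced
        // with a single-element list holding the summed value.
        for (int i = 0; i < outGrads.Count; i++)
            outGrads[i] = new List<int> { outGrads[i].Sum() };

        // Because outGrads aliases grads["op1"], later reads through the dictionary
        // see the aggregated values; copying them into a fresh array would not have that effect.
        Console.WriteLine(grads["op1"][0][0]); // 3
        Console.WriteLine(grads["op1"][1][0]); // 3
    }
}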

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.ImageProcess
     /// </summary>
     public class DigitRecognitionRNN : IExample
     {
-        public bool Enabled { get; set; } = true;
+        public bool Enabled { get; set; } = false;
         public bool IsImportingGraph { get; set; } = false;

         public string Name => "MNIST RNN";

test/TensorFlowNET.UnitTest/GradientTest.cs

Lines changed: 33 additions & 0 deletions
@@ -2,6 +2,7 @@
 using NumSharp;
 using System.Linq;
 using Tensorflow;
+using static Tensorflow.Python;

 namespace TensorFlowNET.UnitTest
 {
@@ -28,6 +29,38 @@ public void Gradients()
             Assert.AreEqual(g[1].name, "gradients/Fill:0");
         }

+        [TestMethod]
+        public void Gradient2x()
+        {
+            var graph = tf.Graph().as_default();
+            with(tf.Session(graph), sess => {
+                var x = tf.constant(7.0f);
+                var y = x * x * tf.constant(0.1f);
+
+                var grad = tf.gradients(y, x);
+                Assert.AreEqual(grad[0].name, "gradients/AddN:0");
+
+                float r = sess.run(grad[0]);
+                Assert.AreEqual(r, 1.4f);
+            });
+        }
+
+        [TestMethod]
+        public void Gradient3x()
+        {
+            var graph = tf.Graph().as_default();
+            with(tf.Session(graph), sess => {
+                var x = tf.constant(7.0f);
+                var y = x * x * x * tf.constant(0.1f);
+
+                var grad = tf.gradients(y, x);
+                Assert.AreEqual(grad[0].name, "gradients/AddN:0");
+
+                float r = sess.run(grad[0]);
+                Assert.AreEqual(r, 14.700001f);
+            });
+        }
+
         [TestMethod]
         public void StridedSlice()
         {
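
The two new tests exercise the aggregation path this commit fixes: differentiating x * x (and x * x * x) produces one partial gradient per use of x, which _AggregatedGrads sums through the add_n branch, and that is why both tests assert the node name gradients/AddN:0. As a quick hand check of the expected values (plain arithmetic, not part of the commit; the class name below is invented): d/dx(0.1 * x^2) = 0.2 * x = 1.4 at x = 7, and d/dx(0.1 * x^3) = 0.3 * x^2 = 14.7 at x = 7, which the graph evaluates as 14.700001f once float rounding is accumulated.

using System;

class GradientHandCheck
{
    static void Main()
    {
        float x = 7.0f;

        // Analytic derivatives of the two test expressions:
        //   y = 0.1 * x^2  ->  dy/dx = 0.2 * x
        //   y = 0.1 * x^3  ->  dy/dx = 0.3 * x^2
        Console.WriteLine(0.2f * x);      // ~1.4, the value Gradient2x expects
        Console.WriteLine(0.3f * x * x);  // ~14.7; Gradient3x expects 14.700001f after float rounding in the graph
    }
}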

0 commit comments
