Skip to content

Commit 8249d8a

Browse files
mishunOceania2018
authored andcommitted
Add test case for tf.reduce_sum(..., axis = ...)
1 parent 5a8b1cb commit 8249d8a

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,24 @@ void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)>
175175
new[] { -1.0, 1.0 });
176176
}
177177

178+
[TestMethod]
179+
public void testReduceSumGradients()
180+
{
181+
var x = tf.placeholder(tf.float64, shape: new Shape(1, 1));
182+
var m = tf.broadcast_to(x, new Shape(2, 3));
183+
var g0 = tf.gradients(tf.reduce_sum(m), x)[0];
184+
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0), x)[0];
185+
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1), x)[0];
186+
187+
using (var session = tf.Session())
188+
{
189+
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, 1.0));
190+
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
191+
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
192+
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
193+
}
194+
}
195+
178196
[TestMethod]
179197
public void testTanhGradient()
180198
{

0 commit comments

Comments
 (0)