Skip to content

Commit 10fda1f

Browse files
mishunOceania2018
authored andcommitted
Fix reduce_sum test case
1 parent 427a2f9 commit 10fda1f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,24 @@ import tensorflow.compat.v1 as tf
185185
x = tf.placeholder(tf.float64, shape = (1, 1))
186186
m = tf.broadcast_to(x, (2, 3))
187187
g0 = tf.gradients(tf.reduce_sum(m), x)[0]
188-
g1 = tf.gradients(tf.reduce_sum(m, axis = 0), x)[0]
189-
g2 = tf.gradients(tf.reduce_sum(m, axis = 1), x)[0]
188+
g1 = tf.gradients(tf.reduce_sum(m, axis = 0)[0], x)[0]
189+
g2 = tf.gradients(tf.reduce_sum(m, axis = 1)[0], x)[0]
190190
with tf.compat.v1.Session() as sess:
191191
(r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]})
192192
*/
193193

194194
var x = tf.placeholder(tf.float64, shape: new Shape(1, 1));
195195
var m = tf.broadcast_to(x, new Shape(2, 3));
196196
var g0 = tf.gradients(tf.reduce_sum(m), x)[0];
197-
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0), x)[0];
198-
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1), x)[0];
197+
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0];
198+
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0];
199199

200200
using (var session = tf.Session())
201201
{
202202
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
203203
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
204-
self.assertFloat64Equal(6.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
205-
self.assertFloat64Equal(6.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
204+
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
205+
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
206206
}
207207
}
208208

0 commit comments

Comments
 (0)