Skip to content

Commit 427a2f9

Browse files
committed
tf.reduce_sum #917
1 parent 8249d8a commit 427a2f9

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,12 @@ public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
529529
}
530530
else if (!input_0_shape.Contains(-1) && !tf.Context.executing_eagerly())
531531
{
532-
throw new NotImplementedException("");
532+
axes = axes.reshape(new Shape(-1));
533+
var shape_tensor = tf.constant(op.inputs[0].shape.as_int_list());
534+
var output_shape_kept_dims = math_ops.reduced_shape(shape_tensor, axes);
535+
var tile_scaling = _safe_shape_div(shape_tensor, output_shape_kept_dims);
536+
grad = array_ops.reshape(grad, output_shape_kept_dims);
537+
return new Tensor[] { array_ops.tile(grad, tile_scaling), null };
533538
}
534539
}
535540
}

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,9 +585,14 @@ private static Tensor size_internal<T>(T input, string name = null, bool optimiz
585585
}
586586

587587
public static Tensor tile(Tensor input, Tensor multiples, string name = null)
588-
{
589-
throw new NotImplementedException("tile");
590-
}
588+
=> tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)
589+
{
590+
GetGradientAttrs = (op) => new
591+
{
592+
T = op.get_attr<TF_DataType>("T"),
593+
Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
594+
}
595+
});
591596

592597
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
593598
{

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)>
178178
[TestMethod]
179179
public void testReduceSumGradients()
180180
{
181+
/* python code
182+
import tensorflow.compat.v1 as tf
183+
tf.disable_v2_behavior()
184+
185+
x = tf.placeholder(tf.float64, shape = (1, 1))
186+
m = tf.broadcast_to(x, (2, 3))
187+
g0 = tf.gradients(tf.reduce_sum(m), x)[0]
188+
g1 = tf.gradients(tf.reduce_sum(m, axis = 0), x)[0]
189+
g2 = tf.gradients(tf.reduce_sum(m, axis = 1), x)[0]
190+
with tf.compat.v1.Session() as sess:
191+
(r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]})
192+
*/
193+
181194
var x = tf.placeholder(tf.float64, shape: new Shape(1, 1));
182195
var m = tf.broadcast_to(x, new Shape(2, 3));
183196
var g0 = tf.gradients(tf.reduce_sum(m), x)[0];
@@ -186,10 +199,10 @@ public void testReduceSumGradients()
186199

187200
using (var session = tf.Session())
188201
{
189-
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, 1.0));
202+
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
190203
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
191-
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
192-
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
204+
self.assertFloat64Equal(6.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
205+
self.assertFloat64Equal(6.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
193206
}
194207
}
195208

0 commit comments

Comments
 (0)