Skip to content

Commit 7d0d271

Browse files
committed
fix reduce_sum when axis is not null
1 parent 6b154dd commit 7d0d271

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ public static Tensor pow<T1, T2>(T1 x, T2 y)
3636
/// <param name="input"></param>
3737
/// <param name="axis"></param>
3838
/// <returns></returns>
39-
public static Tensor reduce_sum(Tensor input, int[] axis = null, int? reduction_indices = null)
40-
=> math_ops.reduce_sum(input);
39+
public static Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null)
40+
{
41+
if(!axis.HasValue && reduction_indices.HasValue)
42+
return math_ops.reduce_sum(input, reduction_indices.Value);
43+
return math_ops.reduce_sum(input);
44+
}
4145

4246
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
4347
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = fals
207207
return _op.outputs[0];
208208
}
209209

210+
public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
211+
{
212+
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
213+
214+
return _op.outputs[0];
215+
}
216+
210217
/// <summary>
211218
/// Creates a sequence of numbers.
212219
/// </summary>

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool ke
209209
return _may_reduce_to_scalar(keepdims, m);
210210
}
211211

212+
public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false)
213+
{
214+
var m = gen_math_ops.sum(input_tensor, axis);
215+
return _may_reduce_to_scalar(keepdims, m);
216+
}
217+
212218
private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output)
213219
{
214220
output.shape = new long[0];
@@ -233,7 +239,7 @@ private static Tensor _ReductionDims(Tensor x, Tensor axis)
233239
return range(0, rank, 1);
234240
}
235241
}
236-
242+
237243
private static Tensor _ReductionDims(Tensor x, int[] axis)
238244
{
239245
if (axis != null)

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public void Run()
2626

2727
private void PrepareData()
2828
{
29-
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
29+
//var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
3030

3131
// tf Graph Input
3232
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
@@ -40,8 +40,11 @@ private void PrepareData()
4040
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax
4141

4242
// Minimize error using cross entropy
43-
var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1);
44-
var cost = tf.reduce_mean(sum);
43+
var log = tf.log(pred);
44+
var mul = y * log;
45+
var sum = tf.reduce_sum(mul, reduction_indices: 1);
46+
var neg = -sum;
47+
var cost = tf.reduce_mean(neg);
4548

4649
// Gradient Descent
4750
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

0 commit comments

Comments
 (0)