Skip to content

Commit 7753183

Browse files
committed
Shape error for gradients/Sum_grad/Tile #193
1 parent 369d8ab commit 7753183

File tree

9 files changed

+27
-22
lines changed

9 files changed

+27
-22
lines changed

src/TensorFlowNET.Core/Framework/common_shapes.py.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,10 @@ public static Tensor _broadcast_shape_helper(Tensor shape_x, Tensor shape_y)
3434
{
3535
return tensor.rank;
3636
}
37+
38+
public static bool has_fully_defined_shape(Tensor tensor)
39+
{
40+
return tensor.getShape().is_fully_defined();
41+
}
3742
}
3843
}

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
//using Newtonsoft.Json;
2+
using System;
23
using System.Collections.Generic;
34
using System.Linq;
45
using System.Text;

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
//using Newtonsoft.Json;
2+
using System;
23
using System.Collections.Generic;
34
using System.Linq;
45
using System.Runtime.InteropServices;

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Google.Protobuf.Collections;
2+
//using Newtonsoft.Json;
23
using System;
34
using System.Collections.Generic;
45
using System.Linq;

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,14 +207,14 @@ public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
207207
return _op.outputs[0];
208208
}
209209

210-
public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null)
210+
public static Tensor _sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null)
211211
{
212212
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
213213

214214
return _op.outputs[0];
215215
}
216216

217-
public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
217+
public static Tensor _sum(Tensor input, int axis, bool keep_dims = false, string name = null)
218218
{
219219
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
220220

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,26 +212,29 @@ public static Tensor __case__(Tensor x, TF_DataType dtype, string name = null)
212212
throw new NotImplementedException();
213213
}
214214

215-
public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false)
215+
public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false, string name = null)
216216
{
217217
var r = _ReductionDims(input_tensor, axis);
218-
var m = gen_math_ops.sum(input_tensor, r);
219-
return _may_reduce_to_scalar(keepdims, m);
218+
var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name);
219+
return _may_reduce_to_scalar(keepdims, axis, m);
220220
}
221221

222222
public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false)
223223
{
224-
var m = gen_math_ops.sum(input_tensor, axis);
225-
return _may_reduce_to_scalar(keepdims, m);
224+
var m = gen_math_ops._sum(input_tensor, axis);
225+
return _may_reduce_to_scalar(keepdims, new int[] { axis }, m);
226226
}
227227

228-
private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output)
228+
private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output)
229229
{
230-
output.shape = new long[0];
230+
if (!common_shapes.has_fully_defined_shape(output) &&
231+
!keepdims &&
232+
axis == null)
233+
output.shape = new long[0];
231234
return output;
232235
}
233236

234-
private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axos, Tensor output)
237+
private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axis, Tensor output)
235238
{
236239
output.shape = new long[0];
237240
return output;

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using NumSharp.Core;
1+
//using Newtonsoft.Json;
2+
using NumSharp.Core;
23
using System;
34
using System.Collections.Generic;
45
using System.Linq;

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ public static TensorShapeProto as_shape<T>(T[] dims)
302302
default:
303303
throw new NotImplementedException("as_shape Not Implemented");
304304
}
305-
dim.Name = $"dim_{i}";
305+
// dim.Name = $"dim_{i}";
306306

307307
shape.Dim.Add(dim);
308308
}
@@ -333,7 +333,7 @@ public static TensorShapeProto as_proto(this TensorShape tshape)
333333
{
334334
var dim = new TensorShapeProto.Types.Dim();
335335
dim.Size = tshape.Dimensions[i];
336-
dim.Name = $"dim_{i}";
336+
//dim.Name = $"dim_{i}";
337337

338338
shape.Dim.Add(dim);
339339
}

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ private void PrepareData()
4949
// Gradient Descent
5050
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
5151

52-
//var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin");
53-
54-
/*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings
55-
{
56-
Formatting = Formatting.Indented
57-
});*/
58-
5952
// Initialize the variables (i.e. assign their default value)
6053
var init = tf.global_variables_initializer();
6154

0 commit comments

Comments
 (0)