Skip to content

Commit a129e61

Browse files
committed
fix scalar ndarray to tensor proto.
1 parent 42ca73c commit a129e61

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public partial class np
1818
public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis));
1919

2020
[AutoNumPy]
21-
public static NDArray stack(NDArray arrays, Axis axis = null) => new NDArray(array_ops.stack(arrays, axis ?? 0));
21+
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));
2222

2323
[AutoNumPy]
2424
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public static string ToString(Shape shape)
9494
{
9595
-1 => "<unknown>",
9696
0 => "()",
97-
1 => $"({shape.dims[0]},)",
97+
1 => $"({shape.dims[0].ToString().Replace("-1", "None")},)",
9898
_ => $"({string.Join(", ", shape.dims).Replace("-1", "None")})"
9999
};
100100
}

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,41 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
134134
TensorShape = shape.as_shape_proto()
135135
};
136136

137-
// scalar
138137
if (values is NDArray nd)
139138
{
140-
var len = nd.dtypesize * nd.size;
141-
byte[] bytes = nd.ToByteArray();
142-
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
139+
// scalar
140+
if (nd.shape.IsScalar)
141+
{
142+
switch (nd.dtype)
143+
{
144+
case TF_DataType.TF_BOOL:
145+
tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
146+
break;
147+
case TF_DataType.TF_UINT8:
148+
tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x).ToArray());
149+
break;
150+
case TF_DataType.TF_INT32:
151+
tensor_proto.IntVal.AddRange(nd.ToArray<int>());
152+
break;
153+
case TF_DataType.TF_INT64:
154+
tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
155+
break;
156+
case TF_DataType.TF_FLOAT:
157+
tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
158+
break;
159+
case TF_DataType.TF_DOUBLE:
160+
tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
161+
break;
162+
default:
163+
throw new Exception("make_tensor_proto Not Implemented");
164+
}
165+
}
166+
else
167+
{
168+
var len = nd.dtypesize * nd.size;
169+
byte[] bytes = nd.ToByteArray();
170+
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
171+
}
143172
}
144173
else if (dtype == TF_DataType.TF_STRING && !(values is NDArray))
145174
{

test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public void empty_zeros_ones_full()
2121
var zeros = np.zeros((2, 2));
2222
var ones = np.ones((2, 2));
2323
var full = np.full((2, 2), 0.1f);
24+
Assert.AreEqual(np.float32, full.dtype);
2425
}
2526

2627
[TestMethod]

0 commit comments

Comments
 (0)