Skip to content

Commit f86d4f5

Browse files
committed
np.any
1 parent a129e61 commit f86d4f5

File tree

7 files changed

+65
-52
lines changed

7 files changed

+65
-52
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,25 @@ void SetData(IEnumerable<Slice> slices, NDArray array)
175175
void SetData(IEnumerable<Slice> slices, NDArray array, int currentNDim, int[] indices)
176176
{
177177
if (dtype != array.dtype)
178-
throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
178+
array = array.astype(dtype);
179+
// throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
179180

180181
if (!slices.Any())
181182
return;
182183

184+
var newshape = ShapeHelper.GetShape(shape, slices.ToArray());
185+
if(newshape.Equals(array.shape))
186+
{
187+
var offset = ShapeHelper.GetOffset(shape, slices.First().Start ?? 0);
188+
unsafe
189+
{
190+
var dst = (byte*)data + (ulong)offset * dtypesize;
191+
System.Buffer.MemoryCopy(array.data.ToPointer(), dst, array.bytesize, array.bytesize);
192+
}
193+
return;
194+
}
195+
196+
183197
var slice = slices.First();
184198

185199
if (slices.Count() == 1)
@@ -204,6 +218,9 @@ void SetData(IEnumerable<Slice> slices, NDArray array, int currentNDim, int[] in
204218
}
205219

206220
currentNDim++;
221+
if (slice.Stop == null)
222+
slice.Stop = (int)dims[currentNDim];
223+
207224
for (var i = slice.Start ?? 0; i < slice.Stop; i++)
208225
{
209226
indices[currentNDim] = i;

src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.Numerics;
5-
using System.Text;
5+
using System.Linq;
66
using static Tensorflow.Binding;
77

88
namespace Tensorflow.NumPy
99
{
1010
public partial class np
1111
{
1212
[AutoNumPy]
13-
public static NDArray any(NDArray a, Axis axis = null) => throw new NotImplementedException("");
13+
public static NDArray any(NDArray a, Axis axis = null) => new NDArray(a.ToArray<bool>().Any(x => x));
1414
[AutoNumPy]
1515
public static NDArray logical_or(NDArray x1, NDArray x2) => new NDArray(tf.logical_or(x1, x2));
1616

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ namespace Tensorflow.NumPy
88
{
99
public partial class np
1010
{
11+
[AutoNumPy]
12+
public static NDArray concatenate(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.concat(arrays, axis));
13+
14+
[AutoNumPy]
15+
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");
16+
1117
[AutoNumPy]
1218
public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException("");
1319

@@ -19,8 +25,5 @@ public partial class np
1925

2026
[AutoNumPy]
2127
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));
22-
23-
[AutoNumPy]
24-
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");
2528
}
2629
}

src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ namespace Tensorflow.NumPy
1111
public partial class np
1212
{
1313
[AutoNumPy]
14-
public static NDArray array(Array data) => new NDArray(data);
14+
public static NDArray array(Array data, TF_DataType? dtype = null)
15+
{
16+
var nd = new NDArray(data);
17+
return dtype == null ? nd : nd.astype(dtype.Value);
18+
}
1519

1620
[AutoNumPy]
1721
public static NDArray array<T>(params T[] data)

src/TensorFlowNET.Core/Numpy/Numpy.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ public partial class np
6969
public static bool array_equal(NDArray a, NDArray b)
7070
=> a.Equals(b);
7171

72-
public static NDArray concatenate(NDArray[] arrays, int axis = 0)
73-
=> throw new NotImplementedException("");
74-
7572
public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8,
7673
bool equal_nan = false) => throw new NotImplementedException("");
7774

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -278,53 +278,37 @@ private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y)
278278

279279
protected static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
280280
{
281-
TF_DataType dtype = TF_DataType.DtInvalid;
282-
283-
if (x is Tensor tl)
284-
{
285-
dtype = tl.dtype.as_base_dtype();
286-
}
287-
288-
if (y is Tensor tr)
289-
{
290-
dtype = tr.dtype.as_base_dtype();
291-
}
292-
293281
return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
294282
{
295-
Tensor result;
296-
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
297-
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");
283+
var dtype = GetBestDType(x, y);
284+
var x1 = ops.convert_to_tensor(x, name: "x", dtype: dtype);
285+
var y1 = ops.convert_to_tensor(y, name: "y", dtype: dtype);
286+
string newname = scope;
298287

299-
switch (name.ToLowerInvariant())
288+
return name.ToLowerInvariant() switch
300289
{
301-
case "add":
302-
result = math_ops.add_v2(x1, y1, name: scope);
303-
break;
304-
case "div":
305-
result = math_ops.div(x1, y1, name: scope);
306-
break;
307-
case "floordiv":
308-
result = gen_math_ops.floor_div(x1, y1, name: scope);
309-
break;
310-
case "truediv":
311-
result = math_ops.truediv(x1, y1, name: scope);
312-
break;
313-
case "mul":
314-
result = math_ops.multiply(x1, y1, name: scope);
315-
break;
316-
case "sub":
317-
result = gen_math_ops.sub(x1, y1, name: scope);
318-
break;
319-
case "mod":
320-
result = gen_math_ops.floor_mod(x1, y1, name: scope);
321-
break;
322-
default:
323-
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}");
324-
}
325-
326-
return result;
290+
"add" => math_ops.add_v2(x1, y1, name: newname),
291+
"div" => math_ops.div(x1, y1, name: newname),
292+
"floordiv" => gen_math_ops.floor_div(x1, y1, name: newname),
293+
"truediv" => math_ops.truediv(x1, y1, name: newname),
294+
"mul" => math_ops.multiply(x1, y1, name: newname),
295+
"sub" => gen_math_ops.sub(x1, y1, name: newname),
296+
"mod" => gen_math_ops.floor_mod(x1, y1, name: newname),
297+
_ => throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}")
298+
};
327299
});
328300
}
301+
302+
static TF_DataType GetBestDType<Tx, Ty>(Tx x, Ty y)
303+
{
304+
var dtype1 = x.GetDataType();
305+
var dtype2 = y.GetDataType();
306+
if (dtype1.is_integer() && dtype2.is_floating())
307+
return dtype2;
308+
else if (dtype1.is_floating() && dtype2.is_integer())
309+
return dtype1;
310+
else
311+
return dtype1;
312+
}
329313
}
330314
}

test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,13 @@ public void astype()
3535
var x1 = x.astype(np.float32);
3636
Assert.AreEqual(x1[2], 200f);
3737
}
38+
39+
[TestMethod]
40+
public void divide()
41+
{
42+
var x = np.array(new float[] { 1, 100, 200 });
43+
var y = x / 2;
44+
Assert.AreEqual(y.dtype, np.float32);
45+
}
3846
}
3947
}

0 commit comments

Comments
 (0)