Skip to content

Commit 1478a2c

Browse files
committed
IEnumerable<NDArray>
1 parent 6adcfae commit 1478a2c

File tree

12 files changed

+49
-382
lines changed

12 files changed

+49
-382
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ public static float time()
299299
where T1 : unmanaged
300300
where T2 : unmanaged
301301
{
302-
var a = t1.AsIterator<T1>();
303-
var b = t2.AsIterator<T2>();
302+
//var a = t1.AsIterator<T1>();
303+
//var b = t2.AsIterator<T2>();
304304
//while (a.HasNext() && b.HasNext())
305305
//yield return (a.MoveNext(), b.MoveNext());
306306
throw new NotImplementedException("");

src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ namespace Tensorflow.NumPy
99
{
1010
public partial class np
1111
{
12+
[AutoNumPy]
13+
public static NDArray any(NDArray a, Axis axis = null) => throw new NotImplementedException("");
1214
[AutoNumPy]
1315
public static NDArray logical_or(NDArray x1, NDArray x2) => new NDArray(tf.logical_or(x1, x2));
1416

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ public partial class np
1717
[AutoNumPy]
1818
public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis));
1919

20+
[AutoNumPy]
21+
public static NDArray stack(NDArray arrays, Axis axis = null) => new NDArray(array_ops.stack(arrays, axis ?? 0));
22+
2023
[AutoNumPy]
2124
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");
2225
}

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ public partial class np
1212
[AutoNumPy]
1313
public static NDArray exp(NDArray x) => new NDArray(tf.exp(x));
1414

15+
[AutoNumPy]
16+
public static NDArray floor(NDArray x) => new NDArray(tf.floor(x));
17+
1518
[AutoNumPy]
1619
public static NDArray log(NDArray x) => new NDArray(tf.log(x));
1720

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,21 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using System.Collections;
1819
using System.Collections.Generic;
1920
using System.Linq;
2021
using System.Text;
21-
using Tensorflow.Eager;
2222
using static Tensorflow.Binding;
2323

2424
namespace Tensorflow.NumPy
2525
{
26-
public partial class NDArray : Tensor
26+
public partial class NDArray : Tensor, IEnumerable<NDArray>
2727
{
2828
public IntPtr data => TensorDataPointer;
2929

30-
public NDArray[] GetNDArrays()
31-
=> throw new NotImplementedException("");
32-
3330
public ValueType GetValue(params int[] indices)
3431
=> throw new NotImplementedException("");
3532

36-
public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
37-
=> throw new NotImplementedException("");
38-
39-
public bool HasNext() => throw new NotImplementedException("");
40-
public T MoveNext<T>() => throw new NotImplementedException("");
4133
[AutoNumPy]
4234
public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape));
4335
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
@@ -46,5 +38,14 @@ public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
4638
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");
4739
public byte[] ToByteArray() => BufferToArray();
4840
public override string ToString() => NDArrayRender.ToString(this);
41+
42+
public IEnumerator<NDArray> GetEnumerator()
43+
{
44+
for (int i = 0; i < dims[0]; i++)
45+
yield return this[i];
46+
}
47+
48+
IEnumerator IEnumerable.GetEnumerator()
49+
=> GetEnumerator();
4950
}
5051
}

src/TensorFlowNET.Core/Numpy/NDIterator.Generic.cs

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/TensorFlowNET.Core/Numpy/NDIterator.cs

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/TensorFlowNET.Core/Numpy/Shape.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ public static implicit operator Shape((int, int, int, int) dims)
9696
public static implicit operator Shape((long, long, long, long) dims)
9797
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
9898

99+
public static implicit operator Shape((int, int, int, int, int) dims)
100+
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
101+
102+
public static implicit operator Shape((long, long, long, long, long) dims)
103+
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
104+
99105
public static implicit operator int[](Shape shape)
100106
=> shape.dims.Select(x => (int)x).ToArray();
101107

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT
10081008
!(paddings_constant is null))
10091009
{
10101010
var new_shape = new List<int>();
1011-
foreach ((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays()))
1011+
foreach ((NDArray padding, int dim) in zip(paddings_constant, input_shape.as_int_list()))
10121012
{
10131013
if (padding is null || dim == -1 || padding.ToArray<int>().Contains(-1))
10141014
new_shape.Add(-1);

test/TensorFlowNET.UnitTest/Basics/SessionTest.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ public void Autocast_Case1()
8282
sess.run(tf.global_variables_initializer());
8383
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));
8484

85-
ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
85+
Assert.AreEqual(ret.shape, (2, 3));
86+
Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
8687
print(ret.dtype);
8788
print(ret);
8889
}
@@ -110,7 +111,8 @@ public void Autocast_Case3()
110111
sess.run(tf.global_variables_initializer());
111112
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
112113

113-
ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
114+
Assert.AreEqual(ret.shape, (2, 3));
115+
Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
114116
print(ret.dtype);
115117
print(ret);
116118
}
@@ -124,7 +126,8 @@ public void Autocast_Case4()
124126
sess.run(tf.global_variables_initializer());
125127
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
126128

127-
ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
129+
Assert.AreEqual(ret.shape, (2, 3));
130+
Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
128131
print(ret.dtype);
129132
print(ret);
130133
}

test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,20 @@ public void shape_helper_get_shape_4dim()
101101
var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true));
102102
Assert.AreEqual(shape3, (4, 3, 2));
103103
}
104+
105+
[TestMethod]
106+
public void iterating()
107+
{
108+
var array = np.array(new[,] { { 0, 3 }, { 2, 2 }, { 3, 1 } });
109+
int i = 0;
110+
foreach(var x in array)
111+
{
112+
if (i == 0)
113+
Assert.AreEqual(x, new[] { 0, 3 });
114+
else
115+
Assert.AreEqual(x, array[i]);
116+
i++;
117+
}
118+
}
104119
}
105120
}

0 commit comments

Comments
 (0)