Skip to content

Commit 5821275

Browse files
SuperpifferOceania2018
authored andcommitted
Reimplemented NDArray == and != operators, handling null values. Added unit tests.
1 parent ccda2c3 commit 5821275

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,24 @@ public partial class NDArray
2525
[AutoNumPy]
2626
public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs));
2727
[AutoNumPy]
28-
public static NDArray operator ==(NDArray lhs, NDArray rhs)
29-
=> rhs is null ? Scalar(false) : new NDArray(math_ops.equal(lhs, rhs));
28+
public static NDArray operator ==(NDArray lhs, NDArray rhs)
29+
{
30+
if(ReferenceEquals(lhs, rhs))
31+
return Scalar(true);
32+
if(lhs is null)
33+
return Scalar(false);
34+
if(rhs is null)
35+
return Scalar(false);
36+
return new NDArray(math_ops.equal(lhs, rhs));
37+
}
3038
[AutoNumPy]
31-
public static NDArray operator !=(NDArray lhs, NDArray rhs)
32-
=> new NDArray(math_ops.not_equal(lhs, rhs));
39+
public static NDArray operator !=(NDArray lhs, NDArray rhs)
40+
{
41+
if(ReferenceEquals(lhs, rhs))
42+
return Scalar(false);
43+
if(lhs is null || rhs is null)
44+
return Scalar(true);
45+
return new NDArray(math_ops.not_equal(lhs, rhs));
46+
}
3347
}
3448
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using Tensorflow.NumPy;
3+
4+
namespace TensorFlowNET.UnitTest.NumPy
5+
{
6+
[TestClass]
7+
public class OperatorsTest
8+
{
9+
[TestMethod]
10+
public void EqualToOperator()
11+
{
12+
NDArray n1 = null;
13+
NDArray n2 = new NDArray(1);
14+
15+
Assert.IsTrue(n1 == null);
16+
Assert.IsFalse(n2 == null);
17+
Assert.IsFalse(n1 == 1);
18+
Assert.IsTrue(n2 == 1);
19+
}
20+
21+
[TestMethod]
22+
public void NotEqualToOperator()
23+
{
24+
NDArray n1 = null;
25+
NDArray n2 = new NDArray(1);
26+
27+
Assert.IsFalse(n1 != null);
28+
Assert.IsTrue(n2 != null);
29+
Assert.IsTrue(n1 != 1);
30+
Assert.IsFalse(n2 != 1);
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)