Skip to content

Commit a642929

Browse files
committed
Add elementwise tests for logical_and, logical_not, logical_or, and logical_xor
1 parent 7c862c0 commit a642929

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,30 +680,59 @@ def test_log10(x):
680680
def test_logaddexp(args):
681681
x1, x2 = args
682682
sanity_check(x1, x2)
683-
# a = _array_module.logaddexp(x1, x2)
683+
_array_module.logaddexp(x1, x2)
684+
# The spec doesn't require any behavior for this function. We could test
685+
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
686+
# don't have tests for this sort of thing for any functions yet.
684687

685688
@given(two_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i)))
686689
def test_logical_and(args):
687690
x1, x2 = args
688691
sanity_check(x1, x2)
689-
# a = _array_module.logical_and(x1, x2)
692+
a = _array_module.logical_and(x1, x2)
693+
694+
# See the comments in test_equal
695+
shape = broadcast_shapes(x1.shape, x2.shape)
696+
_x1 = _array_module.broadcast_to(x1, shape)
697+
_x2 = _array_module.broadcast_to(x2, shape)
698+
699+
for idx in ndindex(shape):
700+
assert a[idx] == (bool(_x1[idx]) and bool(_x2[idx]))
690701

691702
@given(boolean_scalars)
692703
def test_logical_not(x):
693-
# a = _array_module.logical_not(x)
694-
pass
704+
a = _array_module.logical_not(x)
705+
706+
for idx in ndindex(x.shape):
707+
assert a[idx] == (not bool(x[idx]))
695708

696709
@given(two_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i)))
697710
def test_logical_or(args):
698711
x1, x2 = args
699712
sanity_check(x1, x2)
700-
# a = _array_module.logical_or(x1, x2)
713+
a = _array_module.logical_or(x1, x2)
714+
715+
# See the comments in test_equal
716+
shape = broadcast_shapes(x1.shape, x2.shape)
717+
_x1 = _array_module.broadcast_to(x1, shape)
718+
_x2 = _array_module.broadcast_to(x2, shape)
719+
720+
for idx in ndindex(shape):
721+
assert a[idx] == (bool(_x1[idx]) or bool(_x2[idx]))
701722

702723
@given(two_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i)))
703724
def test_logical_xor(args):
704725
x1, x2 = args
705726
sanity_check(x1, x2)
706-
# a = _array_module.logical_xor(x1, x2)
727+
a = _array_module.logical_xor(x1, x2)
728+
729+
# See the comments in test_equal
730+
shape = broadcast_shapes(x1.shape, x2.shape)
731+
_x1 = _array_module.broadcast_to(x1, shape)
732+
_x2 = _array_module.broadcast_to(x2, shape)
733+
734+
for idx in ndindex(shape):
735+
assert a[idx] == (bool(_x1[idx]) ^ bool(_x2[idx]))
707736

708737
@given(two_numeric_dtypes.flatmap(lambda i: two_array_scalars(*i)))
709738
def test_multiply(args):

0 commit comments

Comments
 (0)