Skip to content

Commit 3c8c7a8

Browse files
committed
Remove sanity_check() in elementwise
We now test that `two_mutual_arrays()` generates mutually promotable dtypes
1 parent 2876e61 commit 3c8c7a8

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .. import dtype_helpers as dh
1010
from .. import hypothesis_helpers as hh
1111
from ..test_broadcasting import broadcast_shapes
12-
from ..test_elementwise_functions import sanity_check
1312

1413
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1514
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
@@ -52,7 +51,7 @@ def test_two_broadcastable_shapes(pair):
5251

5352
@given(*hh.two_mutual_arrays())
5453
def test_two_mutual_arrays(x1, x2):
55-
sanity_check(x1, x2)
54+
assert (x1.dtype, x2.dtype) in dh.promotion_table.keys()
5655
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
5756

5857

array_api_tests/test_elementwise_functions.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,6 @@ def two_array_scalars(draw, dtype1, dtype2):
4242
# hh.mutually_promotable_dtypes())
4343
return draw(hh.array_scalars(st.just(dtype1))), draw(hh.array_scalars(st.just(dtype2)))
4444

45-
# TODO: refactor this into dtype_helpers.py, see https://github.com/data-apis/array-api-tests/pull/26
46-
def sanity_check(x1, x2):
47-
try:
48-
dh.promotion_table[x1.dtype, x2.dtype]
49-
except ValueError:
50-
raise RuntimeError("Error in test generation (probably a bug in the test suite")
51-
5245
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
5346
def test_abs(x):
5447
if dh.is_int_dtype(x.dtype):
@@ -94,7 +87,6 @@ def test_acosh(x):
9487

9588
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
9689
def test_add(x1, x2):
97-
sanity_check(x1, x2)
9890
a = xp.add(x1, x2)
9991

10092
b = xp.add(x2, x1)
@@ -136,7 +128,6 @@ def test_atan(x):
136128

137129
@given(*hh.two_mutual_arrays(dh.float_dtypes))
138130
def test_atan2(x1, x2):
139-
sanity_check(x1, x2)
140131
a = xp.atan2(x1, x2)
141132
INFINITY1 = ah.infinity(x1.shape, x1.dtype)
142133
INFINITY2 = ah.infinity(x2.shape, x2.dtype)
@@ -183,7 +174,6 @@ def test_atanh(x):
183174

184175
@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes))
185176
def test_bitwise_and(x1, x2):
186-
sanity_check(x1, x2)
187177
out = xp.bitwise_and(x1, x2)
188178

189179
# TODO: generate indices without broadcasting arrays (see test_equal comment)
@@ -210,7 +200,6 @@ def test_bitwise_and(x1, x2):
210200

211201
@given(*hh.two_mutual_arrays(dh.all_int_dtypes))
212202
def test_bitwise_left_shift(x1, x2):
213-
sanity_check(x1, x2)
214203
assume(not ah.any(ah.isnegative(x2)))
215204
out = xp.bitwise_left_shift(x1, x2)
216205

@@ -249,7 +238,6 @@ def test_bitwise_invert(x):
249238

250239
@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes))
251240
def test_bitwise_or(x1, x2):
252-
sanity_check(x1, x2)
253241
out = xp.bitwise_or(x1, x2)
254242

255243
# TODO: generate indices without broadcasting arrays (see test_equal comment)
@@ -276,7 +264,6 @@ def test_bitwise_or(x1, x2):
276264

277265
@given(*hh.two_mutual_arrays(dh.all_int_dtypes))
278266
def test_bitwise_right_shift(x1, x2):
279-
sanity_check(x1, x2)
280267
assume(not ah.any(ah.isnegative(x2)))
281268
out = xp.bitwise_right_shift(x1, x2)
282269

@@ -297,7 +284,6 @@ def test_bitwise_right_shift(x1, x2):
297284

298285
@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes))
299286
def test_bitwise_xor(x1, x2):
300-
sanity_check(x1, x2)
301287
out = xp.bitwise_xor(x1, x2)
302288

303289
# TODO: generate indices without broadcasting arrays (see test_equal comment)
@@ -356,7 +342,6 @@ def test_cosh(x):
356342

357343
@given(*hh.two_mutual_arrays(dh.float_dtypes))
358344
def test_divide(x1, x2):
359-
sanity_check(x1, x2)
360345
xp.divide(x1, x2)
361346
# There isn't much we can test here. The spec doesn't require any behavior
362347
# beyond the special cases, and indeed, there aren't many mathematical
@@ -367,7 +352,6 @@ def test_divide(x1, x2):
367352

368353
@given(*hh.two_mutual_arrays())
369354
def test_equal(x1, x2):
370-
sanity_check(x1, x2)
371355
a = ah.equal(x1, x2)
372356
# NOTE: ah.assert_exactly_equal() itself uses ah.equal(), so we must be careful
373357
# not to use it here. Otherwise, the test would be circular and
@@ -449,7 +433,6 @@ def test_floor(x):
449433

450434
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
451435
def test_floor_divide(x1, x2):
452-
sanity_check(x1, x2)
453436
if dh.is_int_dtype(x1.dtype):
454437
# The spec does not specify the behavior for division by 0 for integer
455438
# dtypes. A library may choose to raise an exception in this case, so
@@ -473,7 +456,6 @@ def test_floor_divide(x1, x2):
473456

474457
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
475458
def test_greater(x1, x2):
476-
sanity_check(x1, x2)
477459
a = xp.greater(x1, x2)
478460

479461
# See the comments in test_equal() for a description of how this test
@@ -502,7 +484,6 @@ def test_greater(x1, x2):
502484

503485
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
504486
def test_greater_equal(x1, x2):
505-
sanity_check(x1, x2)
506487
a = xp.greater_equal(x1, x2)
507488

508489
# See the comments in test_equal() for a description of how this test
@@ -577,7 +558,6 @@ def test_isnan(x):
577558

578559
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
579560
def test_less(x1, x2):
580-
sanity_check(x1, x2)
581561
a = ah.less(x1, x2)
582562

583563
# See the comments in test_equal() for a description of how this test
@@ -606,7 +586,6 @@ def test_less(x1, x2):
606586

607587
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
608588
def test_less_equal(x1, x2):
609-
sanity_check(x1, x2)
610589
a = ah.less_equal(x1, x2)
611590

612591
# See the comments in test_equal() for a description of how this test
@@ -679,15 +658,13 @@ def test_log10(x):
679658

680659
@given(*hh.two_mutual_arrays(dh.float_dtypes))
681660
def test_logaddexp(x1, x2):
682-
sanity_check(x1, x2)
683661
xp.logaddexp(x1, x2)
684662
# The spec doesn't require any behavior for this function. We could test
685663
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
686664
# don't have tests for this sort of thing for any functions yet.
687665

688666
@given(*hh.two_mutual_arrays([xp.bool]))
689667
def test_logical_and(x1, x2):
690-
sanity_check(x1, x2)
691668
a = ah.logical_and(x1, x2)
692669

693670
# See the comments in test_equal
@@ -707,7 +684,6 @@ def test_logical_not(x):
707684

708685
@given(*hh.two_mutual_arrays([xp.bool]))
709686
def test_logical_or(x1, x2):
710-
sanity_check(x1, x2)
711687
a = ah.logical_or(x1, x2)
712688

713689
# See the comments in test_equal
@@ -720,7 +696,6 @@ def test_logical_or(x1, x2):
720696

721697
@given(*hh.two_mutual_arrays([xp.bool]))
722698
def test_logical_xor(x1, x2):
723-
sanity_check(x1, x2)
724699
a = xp.logical_xor(x1, x2)
725700

726701
# See the comments in test_equal
@@ -733,7 +708,6 @@ def test_logical_xor(x1, x2):
733708

734709
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
735710
def test_multiply(x1, x2):
736-
sanity_check(x1, x2)
737711
a = xp.multiply(x1, x2)
738712

739713
b = xp.multiply(x2, x1)
@@ -762,7 +736,6 @@ def test_negative(x):
762736

763737
@given(*hh.two_mutual_arrays())
764738
def test_not_equal(x1, x2):
765-
sanity_check(x1, x2)
766739
a = xp.not_equal(x1, x2)
767740

768741
# See the comments in test_equal() for a description of how this test
@@ -798,7 +771,6 @@ def test_positive(x):
798771

799772
@given(*hh.two_mutual_arrays(dh.float_dtypes))
800773
def test_pow(x1, x2):
801-
sanity_check(x1, x2)
802774
xp.pow(x1, x2)
803775
# There isn't much we can test here. The spec doesn't require any behavior
804776
# beyond the special cases, and indeed, there aren't many mathematical
@@ -809,7 +781,6 @@ def test_pow(x1, x2):
809781
@given(*hh.two_mutual_arrays(dh.numeric_dtypes))
810782
def test_remainder(x1, x2):
811783
assume(len(x1.shape) <= len(x2.shape)) # TODO: rework same sign testing below to remove this
812-
sanity_check(x1, x2)
813784
out = xp.remainder(x1, x2)
814785

815786
# out and x2 should have the same sign.
@@ -868,7 +839,6 @@ def test_sqrt(x):
868839
@given(two_numeric_dtypes.flatmap(lambda i: two_array_scalars(*i)))
869840
def test_subtract(args):
870841
x1, x2 = args
871-
sanity_check(x1, x2)
872842
# a = xp.subtract(x1, x2)
873843

874844
@given(floating_scalars)

0 commit comments

Comments
 (0)