Skip to content

Commit a257446

Browse files
committed
Use vectorised checks in assert_array_elements
1 parent 10d43c5 commit a257446

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ def test_assert_array_elements():
2020
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
2121
with raises(AssertionError):
2222
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
23+
24+
ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
25+
with raises(AssertionError):
26+
ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from . import _array_module as xp
77
from . import dtype_helpers as dh
88
from . import shape_helpers as sh
9-
from . import stubs
9+
from . import stubs, api_version
10+
from . import xp as _xp
1011
from .typing import Array, DataType, Scalar, ScalarType, Shape
1112

1213
__all__ = [
@@ -420,6 +421,30 @@ def assert_fill(
420421
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
421422

422423

424+
def _real_float_strict_equals(out: Array, expected: Array) -> bool:
425+
assert hasattr(_xp, "signbit") # sanity check
426+
427+
nan_mask = xp.isnan(out)
428+
if not xp.all(nan_mask == xp.isnan(expected)):
429+
return False
430+
431+
out_zero_mask = out == 0
432+
out_sign_mask = xp.signbit(out)
433+
out_pos_zero_mask = out_zero_mask & out_sign_mask
434+
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
435+
expected_zero_mask = expected == 0
436+
expected_sign_mask = xp.signbit(expected)
437+
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
438+
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
439+
if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)):
440+
return False
441+
442+
ignore_mask = nan_mask | out_zero_mask
443+
replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
444+
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
445+
return xp.all(match)
446+
447+
423448
def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
424449
if xp.isnan(at_expected):
425450
assert xp.isnan(at_out), msg
@@ -460,31 +485,40 @@ def assert_array_elements(
460485
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
461486
f_func = f"[{func_name}({fmt_kw(kw)})]"
462487

463-
match = (out == expected)
464-
if xp.all(match):
465-
return
488+
# First we try short-circuit for a successful assertion by using vectorised checks.
489+
if out.dtype in dh.real_float_dtypes and api_version >= "2023.12":
490+
if _real_float_strict_equals(out, expected):
491+
return
492+
elif out.dtype in dh.complex_dtypes and api_version >= "2023.12":
493+
real_match = _real_float_strict_equals(out.real, expected.real)
494+
imag_match = _real_float_strict_equals(out.imag, expected.imag)
495+
if real_match and imag_match:
496+
return
497+
else:
498+
match = out == expected
499+
if xp.all(match):
500+
return
466501

467502
# In case of mismatch, generate a more helpful error. Cycling through all indices is
468503
# costly in some array api implementations, so we only do this in the case of a failure.
504+
msg_template = "{}={}, but should be {} " + f_func
469505
if out.dtype in dh.real_float_dtypes:
470506
for idx in sh.ndindex(out.shape):
471507
at_out = out[idx]
472508
at_expected = expected[idx]
473-
msg = (
474-
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
475-
f"{f_func}"
476-
)
509+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
477510
_assert_float_element(at_out, at_expected, msg)
478511
elif out.dtype in dh.complex_dtypes:
479512
assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
480513
for idx in sh.ndindex(out.shape):
481514
at_out = out[idx]
482515
at_expected = expected[idx]
483-
msg = (
484-
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
485-
f"{f_func}"
486-
)
516+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
487517
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
488518
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
489519
else:
490-
assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
520+
for idx in sh.ndindex(out.shape):
521+
at_out = out[idx]
522+
at_expected = expected[idx]
523+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
524+
assert at_out == at_expected, msg

0 commit comments

Comments
 (0)