Skip to content

Commit 62bbd3f

Browse files
jakevdphonno
authored andcommitted
Make assert_array_elements more efficient in the non-error case
1 parent 9afe8c7 commit 62bbd3f

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,13 @@ def assert_array_elements(
459459
dh.result_type(out.dtype, expected.dtype) # sanity check
460460
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
461461
f_func = f"[{func_name}({fmt_kw(kw)})]"
462+
463+
match = (out == expected)
464+
if xp.all(match):
465+
return
466+
467+
# In case of mismatch, generate a more helpful error. Cycling through all indices is
468+
# costly in some array api implementations, so we only do this in the case of a failure.
462469
if out.dtype in dh.real_float_dtypes:
463470
for idx in sh.ndindex(out.shape):
464471
at_out = out[idx]
@@ -480,6 +487,4 @@ def assert_array_elements(
480487
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
481488
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
482489
else:
483-
assert xp.all(
484-
out == expected
485-
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
490+
assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"

0 commit comments

Comments
 (0)