Skip to content

Commit 2c12312

Browse files
authored
Merge pull request #236 from honno/assert-array-elements-vectorisation
Short-circuit with vectorisation `ph.assert_array_elements()`
2 parents 1cf4a07 + 82125d1 commit 2c12312

File tree

4 files changed

+69
-21
lines changed

4 files changed

+69
-21
lines changed

array-api

Submodule array-api updated 116 files

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ def test_assert_array_elements():
2020
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
2121
with raises(AssertionError):
2222
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
23+
24+
ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
25+
with raises(AssertionError):
26+
ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import dtype_helpers as dh
88
from . import shape_helpers as sh
99
from . import stubs
10+
from . import xp as _xp
1011
from .typing import Array, DataType, Scalar, ScalarType, Shape
1112

1213
__all__ = [
@@ -420,6 +421,35 @@ def assert_fill(
420421
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
421422

422423

424+
def _real_float_strict_equals(out: Array, expected: Array) -> bool:
425+
nan_mask = xp.isnan(out)
426+
if not xp.all(nan_mask == xp.isnan(expected)):
427+
return False
428+
ignore_mask = nan_mask
429+
430+
# Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
431+
# not that big of a deal for the perf costs.
432+
if hasattr(_xp, "signbit"):
433+
out_zero_mask = out == 0
434+
out_sign_mask = _xp.signbit(out)
435+
out_pos_zero_mask = out_zero_mask & out_sign_mask
436+
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
437+
expected_zero_mask = expected == 0
438+
expected_sign_mask = _xp.signbit(expected)
439+
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
440+
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
441+
pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
442+
neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
443+
if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
444+
return False
445+
ignore_mask |= out_zero_mask
446+
447+
replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
448+
assert replacement == replacement # sanity check
449+
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
450+
return xp.all(match)
451+
452+
423453
def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
424454
if xp.isnan(at_expected):
425455
assert xp.isnan(at_out), msg
@@ -455,31 +485,45 @@ def assert_array_elements(
455485
>>> assert xp.all(out == x)
456486
457487
"""
458-
__tracebackhide__ = True
488+
# __tracebackhide__ = True
459489
dh.result_type(out.dtype, expected.dtype) # sanity check
460490
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
461491
f_func = f"[{func_name}({fmt_kw(kw)})]"
492+
493+
# First we try short-circuit for a successful assertion by using vectorised checks.
494+
if out.dtype in dh.real_float_dtypes:
495+
if _real_float_strict_equals(out, expected):
496+
return
497+
elif out.dtype in dh.complex_dtypes:
498+
real_match = _real_float_strict_equals(xp.real(out), xp.real(expected))
499+
imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected))
500+
if real_match and imag_match:
501+
return
502+
else:
503+
match = out == expected
504+
if xp.all(match):
505+
return
506+
507+
# In case of mismatch, generate a more helpful error. Cycling through all indices is
508+
# costly in some array api implementations, so we only do this in the case of a failure.
509+
msg_template = "{}={}, but should be {} " + f_func
462510
if out.dtype in dh.real_float_dtypes:
463511
for idx in sh.ndindex(out.shape):
464512
at_out = out[idx]
465513
at_expected = expected[idx]
466-
msg = (
467-
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
468-
f"{f_func}"
469-
)
514+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
470515
_assert_float_element(at_out, at_expected, msg)
471516
elif out.dtype in dh.complex_dtypes:
472517
assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
473518
for idx in sh.ndindex(out.shape):
474519
at_out = out[idx]
475520
at_expected = expected[idx]
476-
msg = (
477-
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
478-
f"{f_func}"
479-
)
521+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
480522
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
481523
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
482524
else:
483-
assert xp.all(
484-
out == expected
485-
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
525+
for idx in sh.ndindex(out.shape):
526+
at_out = out[idx]
527+
at_expected = expected[idx]
528+
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
529+
assert at_out == at_expected, msg

array_api_tests/test_creation_functions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw):
354354
ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype)
355355
_n_cols = n_rows if n_cols is None else n_cols
356356
ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols))
357-
f_func = f"[eye({n_rows=}, {n_cols=})]"
358-
for i in range(n_rows):
359-
for j in range(_n_cols):
360-
f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
361-
if j - i == kw.get("k", 0):
362-
assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
363-
else:
364-
assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"
357+
k = kw.get("k", 0)
358+
expected = xp.asarray(
359+
[[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)],
360+
dtype=out.dtype # Note: dtype already checked above.
361+
)
362+
if expected.size == 0:
363+
expected = xp.reshape(expected, (n_rows, _n_cols))
364+
ph.assert_array_elements("eye", out=out, expected=expected, kw=kw)
365365

366366

367367
default_unsafe_dtypes = [xp.uint64]

0 commit comments

Comments
 (0)