Skip to content

Commit a0fde7b

Browse files
committed
Ignore signbit testing if not available
1 parent 2cb2e5d commit a0fde7b

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -422,25 +422,30 @@ def assert_fill(
422422

423423

424424
def _real_float_strict_equals(out: Array, expected: Array) -> bool:
425-
assert hasattr(_xp, "signbit") # sanity check
426-
427425
nan_mask = xp.isnan(out)
428426
if not xp.all(nan_mask == xp.isnan(expected)):
429427
return False
428+
ignore_mask = nan_mask
429+
430+
# Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
431+
# not that big of a deal for the perf costs.
432+
if api_version >= "2023.12" and hasattr(_xp, "signbit"):
433+
out_zero_mask = out == 0
434+
out_sign_mask = xp.signbit(out)
435+
out_pos_zero_mask = out_zero_mask & out_sign_mask
436+
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
437+
expected_zero_mask = expected == 0
438+
expected_sign_mask = xp.signbit(expected)
439+
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
440+
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
441+
pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
442+
neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
443+
if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
444+
return False
445+
ignore_mask |= out_zero_mask
430446

431-
out_zero_mask = out == 0
432-
out_sign_mask = xp.signbit(out)
433-
out_pos_zero_mask = out_zero_mask & out_sign_mask
434-
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
435-
expected_zero_mask = expected == 0
436-
expected_sign_mask = xp.signbit(expected)
437-
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
438-
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
439-
if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)):
440-
return False
441-
442-
ignore_mask = nan_mask | out_zero_mask
443447
replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
448+
assert replacement == replacement # sanity check
444449
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
445450
return xp.all(match)
446451

@@ -486,10 +491,10 @@ def assert_array_elements(
486491
f_func = f"[{func_name}({fmt_kw(kw)})]"
487492

488493
# First we try short-circuit for a successful assertion by using vectorised checks.
489-
if out.dtype in dh.real_float_dtypes and api_version >= "2023.12":
494+
if out.dtype in dh.real_float_dtypes:
490495
if _real_float_strict_equals(out, expected):
491496
return
492-
elif out.dtype in dh.complex_dtypes and api_version >= "2023.12":
497+
elif out.dtype in dh.complex_dtypes:
493498
real_match = _real_float_strict_equals(out.real, expected.real)
494499
imag_match = _real_float_strict_equals(out.imag, expected.imag)
495500
if real_match and imag_match:

0 commit comments

Comments
 (0)