Skip to content

Commit 2cfaeb7

Browse files
committed
Ignore signbit testing if not available
1 parent a257446 commit 2cfaeb7

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -422,25 +422,30 @@ def assert_fill(
422422

423423

424424
def _real_float_strict_equals(out: Array, expected: Array) -> bool:
425-
assert hasattr(_xp, "signbit") # sanity check
426-
427425
nan_mask = xp.isnan(out)
428426
if not xp.all(nan_mask == xp.isnan(expected)):
429427
return False
428+
ignore_mask = nan_mask
429+
430+
# Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
431+
# not that big of a deal for the perf costs.
432+
if hasattr(_xp, "signbit"):
433+
out_zero_mask = out == 0
434+
out_sign_mask = xp.signbit(out)
435+
out_pos_zero_mask = out_zero_mask & out_sign_mask
436+
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
437+
expected_zero_mask = expected == 0
438+
expected_sign_mask = xp.signbit(expected)
439+
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
440+
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
441+
pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
442+
neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
443+
if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
444+
return False
445+
ignore_mask |= out_zero_mask
430446

431-
out_zero_mask = out == 0
432-
out_sign_mask = xp.signbit(out)
433-
out_pos_zero_mask = out_zero_mask & out_sign_mask
434-
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
435-
expected_zero_mask = expected == 0
436-
expected_sign_mask = xp.signbit(expected)
437-
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
438-
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
439-
if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)):
440-
return False
441-
442-
ignore_mask = nan_mask | out_zero_mask
443447
replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
448+
assert replacement == replacement # sanity check
444449
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
445450
return xp.all(match)
446451

0 commit comments

Comments
 (0)