@@ -422,25 +422,30 @@ def assert_fill(
422
422
423
423
424
424
def _real_float_strict_equals (out : Array , expected : Array ) -> bool :
425
- assert hasattr (_xp , "signbit" ) # sanity check
426
-
427
425
nan_mask = xp .isnan (out )
428
426
if not xp .all (nan_mask == xp .isnan (expected )):
429
427
return False
428
+ ignore_mask = nan_mask
429
+
430
+ # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
431
+ # not that big of a deal for the perf costs.
432
+ if api_version >= "2023.12" and hasattr (_xp , "signbit" ):
433
+ out_zero_mask = out == 0
434
+ out_sign_mask = xp .signbit (out )
435
+ out_pos_zero_mask = out_zero_mask & out_sign_mask
436
+ out_neg_zero_mask = out_zero_mask & ~ out_sign_mask
437
+ expected_zero_mask = expected == 0
438
+ expected_sign_mask = xp .signbit (expected )
439
+ expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
440
+ expected_neg_zero_mask = expected_zero_mask & ~ expected_sign_mask
441
+ pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
442
+ neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
443
+ if not (xp .all (pos_zero_match ) and xp .all (neg_zero_match )):
444
+ return False
445
+ ignore_mask |= out_zero_mask
430
446
431
- out_zero_mask = out == 0
432
- out_sign_mask = xp .signbit (out )
433
- out_pos_zero_mask = out_zero_mask & out_sign_mask
434
- out_neg_zero_mask = out_zero_mask & ~ out_sign_mask
435
- expected_zero_mask = expected == 0
436
- expected_sign_mask = xp .signbit (expected )
437
- expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
438
- expected_neg_zero_mask = expected_zero_mask & ~ expected_sign_mask
439
- if not (xp .all (out_pos_zero_mask == expected_pos_zero_mask ) and xp .all (out_neg_zero_mask == expected_neg_zero_mask )):
440
- return False
441
-
442
- ignore_mask = nan_mask | out_zero_mask
443
447
replacement = xp .asarray (42 , dtype = out .dtype ) # i.e. an arbitrary non-zero value that equals itself
448
+ assert replacement == replacement # sanity check
444
449
match = xp .where (ignore_mask , replacement , out ) == xp .where (ignore_mask , replacement , expected )
445
450
return xp .all (match )
446
451
@@ -486,10 +491,10 @@ def assert_array_elements(
486
491
f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
487
492
488
493
# First we try short-circuit for a successful assertion by using vectorised checks.
489
- if out .dtype in dh .real_float_dtypes and api_version >= "2023.12" :
494
+ if out .dtype in dh .real_float_dtypes :
490
495
if _real_float_strict_equals (out , expected ):
491
496
return
492
- elif out .dtype in dh .complex_dtypes and api_version >= "2023.12" :
497
+ elif out .dtype in dh .complex_dtypes :
493
498
real_match = _real_float_strict_equals (out .real , expected .real )
494
499
imag_match = _real_float_strict_equals (out .imag , expected .imag )
495
500
if real_match and imag_match :
0 commit comments