|
6 | 6 | from . import _array_module as xp
|
7 | 7 | from . import dtype_helpers as dh
|
8 | 8 | from . import shape_helpers as sh
|
9 |
| -from . import stubs |
| 9 | +from . import stubs, api_version |
| 10 | +from . import xp as _xp |
10 | 11 | from .typing import Array, DataType, Scalar, ScalarType, Shape
|
11 | 12 |
|
12 | 13 | __all__ = [
|
@@ -420,6 +421,30 @@ def assert_fill(
|
420 | 421 | assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
|
421 | 422 |
|
422 | 423 |
|
| 424 | +def _real_float_strict_equals(out: Array, expected: Array) -> bool: |
| 425 | + assert hasattr(_xp, "signbit") # sanity check |
| 426 | + |
| 427 | + nan_mask = xp.isnan(out) |
| 428 | + if not xp.all(nan_mask == xp.isnan(expected)): |
| 429 | + return False |
| 430 | + |
| 431 | + out_zero_mask = out == 0 |
| 432 | + out_sign_mask = xp.signbit(out) |
| 433 | + out_pos_zero_mask = out_zero_mask & out_sign_mask |
| 434 | + out_neg_zero_mask = out_zero_mask & ~out_sign_mask |
| 435 | + expected_zero_mask = expected == 0 |
| 436 | + expected_sign_mask = xp.signbit(expected) |
| 437 | + expected_pos_zero_mask = expected_zero_mask & expected_sign_mask |
| 438 | + expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask |
| 439 | + if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)): |
| 440 | + return False |
| 441 | + |
| 442 | + ignore_mask = nan_mask | out_zero_mask |
| 443 | + replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself |
| 444 | + match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) |
| 445 | + return xp.all(match) |
| 446 | + |
| 447 | + |
423 | 448 | def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
|
424 | 449 | if xp.isnan(at_expected):
|
425 | 450 | assert xp.isnan(at_out), msg
|
@@ -460,31 +485,40 @@ def assert_array_elements(
|
460 | 485 | assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
|
461 | 486 | f_func = f"[{func_name}({fmt_kw(kw)})]"
|
462 | 487 |
|
463 |
| - match = (out == expected) |
464 |
| - if xp.all(match): |
465 |
| - return |
| 488 | + # First we try short-circuit for a successful assertion by using vectorised checks. |
| 489 | + if out.dtype in dh.real_float_dtypes and api_version >= "2023.12": |
| 490 | + if _real_float_strict_equals(out, expected): |
| 491 | + return |
| 492 | + elif out.dtype in dh.complex_dtypes and api_version >= "2023.12": |
| 493 | + real_match = _real_float_strict_equals(out.real, expected.real) |
| 494 | + imag_match = _real_float_strict_equals(out.imag, expected.imag) |
| 495 | + if real_match and imag_match: |
| 496 | + return |
| 497 | + else: |
| 498 | + match = out == expected |
| 499 | + if xp.all(match): |
| 500 | + return |
466 | 501 |
|
467 | 502 | # In case of mismatch, generate a more helpful error. Cycling through all indices is
|
468 | 503 | # costly in some array api implementations, so we only do this in the case of a failure.
|
| 504 | + msg_template = "{}={}, but should be {} " + f_func |
469 | 505 | if out.dtype in dh.real_float_dtypes:
|
470 | 506 | for idx in sh.ndindex(out.shape):
|
471 | 507 | at_out = out[idx]
|
472 | 508 | at_expected = expected[idx]
|
473 |
| - msg = ( |
474 |
| - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " |
475 |
| - f"{f_func}" |
476 |
| - ) |
| 509 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
477 | 510 | _assert_float_element(at_out, at_expected, msg)
|
478 | 511 | elif out.dtype in dh.complex_dtypes:
|
479 | 512 | assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
|
480 | 513 | for idx in sh.ndindex(out.shape):
|
481 | 514 | at_out = out[idx]
|
482 | 515 | at_expected = expected[idx]
|
483 |
| - msg = ( |
484 |
| - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " |
485 |
| - f"{f_func}" |
486 |
| - ) |
| 516 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
487 | 517 | _assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
|
488 | 518 | _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
|
489 | 519 | else:
|
490 |
| - assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" |
| 520 | + for idx in sh.ndindex(out.shape): |
| 521 | + at_out = out[idx] |
| 522 | + at_expected = expected[idx] |
| 523 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
| 524 | + assert at_out == at_expected, msg |
0 commit comments