|
7 | 7 | from . import dtype_helpers as dh
|
8 | 8 | from . import shape_helpers as sh
|
9 | 9 | from . import stubs
|
| 10 | +from . import xp as _xp |
10 | 11 | from .typing import Array, DataType, Scalar, ScalarType, Shape
|
11 | 12 |
|
12 | 13 | __all__ = [
|
@@ -420,6 +421,35 @@ def assert_fill(
|
420 | 421 | assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
|
421 | 422 |
|
422 | 423 |
|
| 424 | +def _real_float_strict_equals(out: Array, expected: Array) -> bool: |
| 425 | + nan_mask = xp.isnan(out) |
| 426 | + if not xp.all(nan_mask == xp.isnan(expected)): |
| 427 | + return False |
| 428 | + ignore_mask = nan_mask |
| 429 | + |
| 430 | + # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's |
| 431 | + # not that big of a deal for the perf costs. |
| 432 | + if hasattr(_xp, "signbit"): |
| 433 | + out_zero_mask = out == 0 |
| 434 | + out_sign_mask = _xp.signbit(out) |
| 435 | + out_pos_zero_mask = out_zero_mask & out_sign_mask |
| 436 | + out_neg_zero_mask = out_zero_mask & ~out_sign_mask |
| 437 | + expected_zero_mask = expected == 0 |
| 438 | + expected_sign_mask = _xp.signbit(expected) |
| 439 | + expected_pos_zero_mask = expected_zero_mask & expected_sign_mask |
| 440 | + expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask |
| 441 | + pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask |
| 442 | + neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask |
| 443 | + if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)): |
| 444 | + return False |
| 445 | + ignore_mask |= out_zero_mask |
| 446 | + |
| 447 | + replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself |
| 448 | + assert replacement == replacement # sanity check |
| 449 | + match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) |
| 450 | + return xp.all(match) |
| 451 | + |
| 452 | + |
423 | 453 | def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
|
424 | 454 | if xp.isnan(at_expected):
|
425 | 455 | assert xp.isnan(at_out), msg
|
@@ -455,31 +485,45 @@ def assert_array_elements(
|
455 | 485 | >>> assert xp.all(out == x)
|
456 | 486 |
|
457 | 487 | """
|
458 |
| - __tracebackhide__ = True |
| 488 | + # __tracebackhide__ = True |
459 | 489 | dh.result_type(out.dtype, expected.dtype) # sanity check
|
460 | 490 | assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
|
461 | 491 | f_func = f"[{func_name}({fmt_kw(kw)})]"
|
| 492 | + |
| 493 | + # First we try short-circuit for a successful assertion by using vectorised checks. |
| 494 | + if out.dtype in dh.real_float_dtypes: |
| 495 | + if _real_float_strict_equals(out, expected): |
| 496 | + return |
| 497 | + elif out.dtype in dh.complex_dtypes: |
| 498 | + real_match = _real_float_strict_equals(xp.real(out), xp.real(expected)) |
| 499 | + imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected)) |
| 500 | + if real_match and imag_match: |
| 501 | + return |
| 502 | + else: |
| 503 | + match = out == expected |
| 504 | + if xp.all(match): |
| 505 | + return |
| 506 | + |
| 507 | + # In case of mismatch, generate a more helpful error. Cycling through all indices is |
| 508 | + # costly in some array api implementations, so we only do this in the case of a failure. |
| 509 | + msg_template = "{}={}, but should be {} " + f_func |
462 | 510 | if out.dtype in dh.real_float_dtypes:
|
463 | 511 | for idx in sh.ndindex(out.shape):
|
464 | 512 | at_out = out[idx]
|
465 | 513 | at_expected = expected[idx]
|
466 |
| - msg = ( |
467 |
| - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " |
468 |
| - f"{f_func}" |
469 |
| - ) |
| 514 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
470 | 515 | _assert_float_element(at_out, at_expected, msg)
|
471 | 516 | elif out.dtype in dh.complex_dtypes:
|
472 | 517 | assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
|
473 | 518 | for idx in sh.ndindex(out.shape):
|
474 | 519 | at_out = out[idx]
|
475 | 520 | at_expected = expected[idx]
|
476 |
| - msg = ( |
477 |
| - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " |
478 |
| - f"{f_func}" |
479 |
| - ) |
| 521 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
480 | 522 | _assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
|
481 | 523 | _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
|
482 | 524 | else:
|
483 |
| - assert xp.all( |
484 |
| - out == expected |
485 |
| - ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" |
| 525 | + for idx in sh.ndindex(out.shape): |
| 526 | + at_out = out[idx] |
| 527 | + at_expected = expected[idx] |
| 528 | + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) |
| 529 | + assert at_out == at_expected, msg |
0 commit comments