From 62bbd3f4d74a724a6ba54b0bdeac23616cb1e206 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 13 Nov 2023 10:17:32 -0800 Subject: [PATCH 1/2] Make assert_array_elements more efficient in the non-error case --- array_api_tests/pytest_helpers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index e6ede7b2..6d3899a9 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -459,6 +459,13 @@ def assert_array_elements( dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" + + match = (out == expected) + if xp.all(match): + return + + # In case of mismatch, generate a more helpful error. Cycling through all indices is + # costly in some array api implementations, so we only do this in the case of a failure. if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] @@ -480,6 +487,4 @@ def assert_array_elements( _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: - assert xp.all( - out == expected - ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" From 89b2112900e9a75c64e111cb984d45754d1486ce Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 13 Nov 2023 10:18:01 -0800 Subject: [PATCH 2/2] test_eye: use assert_array_elements utility --- array_api_tests/test_creation_functions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 94b6b0ec..64959841 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw): ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) _n_cols = n_rows if n_cols is None else n_cols ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) - f_func = f"[eye({n_rows=}, {n_cols=})]" - for i in range(n_rows): - for j in range(_n_cols): - f_indexed_out = f"out[{i}, {j}]={out[i, j]}" - if j - i == kw.get("k", 0): - assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}" - else: - assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}" + k = kw.get("k", 0) + expected = xp.asarray( + [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], + dtype=out.dtype # Note: dtype already checked above. + ) + if expected.size == 0: + expected = xp.reshape(expected, (n_rows, _n_cols)) + ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) default_unsafe_dtypes = [xp.uint64]