From 9ee83813a83c4ab0a9862a04297fb70b79f9aedf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 18 Jan 2022 12:01:33 +0000 Subject: [PATCH 1/5] Basic mask tests --- array_api_tests/test_array_object.py | 45 +++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 78385cbe..24a13203 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -107,7 +107,50 @@ def test_setitem(shape, data): ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key]) -# TODO: test boolean indexing +# TODO: make mask tests optional + + +@given(hh.shapes(), st.data()) +def test_getitem_mask(shape, data): + x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") + mask_shapes = st.one_of( + st.sampled_from([x.shape, ()]), + st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map( + lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l)) + ), + hh.shapes(), + ) + key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key") + + if key.ndim > x.ndim or not all( + ks in (xs, 0) for xs, ks in zip(x.shape, key.shape) + ): + with pytest.raises(IndexError): + x[key] + return + + out = x[key] + + ph.assert_dtype("__getitem__", x.dtype, out.dtype) + if key.ndim == 0: + out_shape = (1,) if key else (0,) + out_shape += x.shape + else: + size = int(xp.sum(xp.astype(key, xp.uint8))) + out_shape = (size,) + x.shape[key.ndim :] + ph.assert_shape("__getitem__", out.shape, out_shape) + + +@given(hh.shapes(min_side=1), st.data()) +def test_setitem_mask(shape, data): + x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") + key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") + value = data.draw(xps.from_dtype(x.dtype), label="value") + + res = xp.asarray(x, copy=True) + res[key] = value + + # TODO def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param: From aea80971616eba1b837e8a8163d9cb444791d0b0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 18 Jan 2022 12:28:39 +0000 Subject: [PATCH 2/5] Rudimentary out elements assertions for getitem masking --- array_api_tests/test_array_object.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 24a13203..de800364 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -10,6 +10,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .typing import DataType, Param, Scalar, ScalarType, Shape @@ -140,12 +141,26 @@ def test_getitem_mask(shape, data): out_shape = (size,) + x.shape[key.ndim :] ph.assert_shape("__getitem__", out.shape, out_shape) + if not any(s == 0 for s in key.shape): + assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios + out_indices = sh.ndindex(out.shape) + for x_idx in sh.ndindex(x.shape): + if key[x_idx]: + out_idx = next(out_indices) + ph.assert_0d_equals( + "__getitem__", + f"x[{x_idx}]", + x[x_idx], + f"out[{out_idx}]", + out[out_idx], + ) + @given(hh.shapes(min_side=1), st.data()) def test_setitem_mask(shape, data): x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") - value = data.draw(xps.from_dtype(x.dtype), label="value") + value = data.draw(xps.from_dtype(x.dtype), label="value") # TODO: more values res = xp.asarray(x, copy=True) res[key] = value From fe9e122f1456212857e9f9f5f1007915e9865395 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 18 Jan 2022 13:10:14 +0000 Subject: [PATCH 3/5] Cover much more in `test_setitem_mask` --- array_api_tests/pytest_helpers.py | 9 +++++---- array_api_tests/test_array_object.py | 30 +++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 2cb78cb0..64e39aa4 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -204,18 +204,19 @@ def assert_scalar_equals( out: Scalar, expected: Scalar, /, + repr_name: str = "out", **kw, ): - out_repr = "out" if idx == () else f"out[{idx}]" + repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ is bool or type_ is int: - msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert out == expected, msg elif math.isnan(expected): - msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert math.isnan(out), msg else: - msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index de800364..6dbd4a9a 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -156,16 +156,40 @@ def test_getitem_mask(shape, data): ) -@given(hh.shapes(min_side=1), st.data()) +@given(hh.shapes(), st.data()) def test_setitem_mask(shape, data): x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") - value = data.draw(xps.from_dtype(x.dtype), label="value") # TODO: more values + value = data.draw( + xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value" + ) res = xp.asarray(x, copy=True) res[key] = value - # TODO + ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype") + + scalar_type = dh.get_scalar_type(x.dtype) + for idx in sh.ndindex(x.shape): + if key[idx]: + if isinstance(value, Scalar): + ph.assert_scalar_equals( + "__setitem__", + scalar_type, + idx, + scalar_type(res[idx]), + value, + repr_name="modified x", + ) + else: + ph.assert_0d_equals( + "__setitem__", "value", value, f"modified x[{idx}]", res[idx] + ) + else: + ph.assert_0d_equals( + "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + ) def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param: From e83b101cac05486fecba175781f0aa245fd3e246 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 18 Jan 2022 18:05:56 +0000 Subject: [PATCH 4/5] Test unaffected indices of `test_setitem` --- array_api_tests/test_array_object.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 6dbd4a9a..8efaa1a6 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -88,6 +88,7 @@ def test_setitem(shape, data): ) x = xp.asarray(obj, dtype=dtype) note(f"{x=}") + # TODO: test setting non-0d arrays key = data.draw(xps.indices(shape=shape, max_dims=0), label="key") value = data.draw( xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value" @@ -105,7 +106,18 @@ def test_setitem(shape, data): else: assert res[key] == value, msg else: - ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key]) + ph.assert_0d_equals( + "__setitem__", "value", value, f"modified x[{key}]", res[key] + ) + _key = key if isinstance(key, tuple) else (key,) + assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis + _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) + unaffected_indices = list(sh.ndindex(res.shape)) + unaffected_indices.remove(_key) + for idx in unaffected_indices: + ph.assert_0d_equals( + "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + ) # TODO: make mask tests optional @@ -140,7 +152,6 @@ def test_getitem_mask(shape, data): size = int(xp.sum(xp.astype(key, xp.uint8))) out_shape = (size,) + x.shape[key.ndim :] ph.assert_shape("__getitem__", out.shape, out_shape) - if not any(s == 0 for s in key.shape): assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios out_indices = sh.ndindex(out.shape) @@ -169,7 +180,6 @@ def test_setitem_mask(shape, data): ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype") - scalar_type = dh.get_scalar_type(x.dtype) for idx in sh.ndindex(x.shape): if key[idx]: From be70f2615db5e67309a1db56454eacaa2033263b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 19 Jan 2022 09:27:10 +0000 Subject: [PATCH 5/5] `mask` -> `masking` for indexing test names --- array_api_tests/test_array_object.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 8efaa1a6..397d823d 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -124,7 +124,7 @@ def test_setitem(shape, data): @given(hh.shapes(), st.data()) -def test_getitem_mask(shape, data): +def test_getitem_masking(shape, data): x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") mask_shapes = st.one_of( st.sampled_from([x.shape, ()]), @@ -168,7 +168,7 @@ def test_getitem_mask(shape, data): @given(hh.shapes(), st.data()) -def test_setitem_mask(shape, data): +def test_setitem_masking(shape, data): x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") value = data.draw(