diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 2cb78cb0..64e39aa4 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -204,18 +204,19 @@ def assert_scalar_equals( out: Scalar, expected: Scalar, /, + repr_name: str = "out", **kw, ): - out_repr = "out" if idx == () else f"out[{idx}]" + repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ is bool or type_ is int: - msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert out == expected, msg elif math.isnan(expected): - msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert math.isnan(out), msg else: - msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]" + msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 78385cbe..397d823d 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -10,6 +10,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .typing import DataType, Param, Scalar, ScalarType, Shape @@ -87,6 +88,7 @@ def test_setitem(shape, data): ) x = xp.asarray(obj, dtype=dtype) note(f"{x=}") + # TODO: test setting non-0d arrays key = data.draw(xps.indices(shape=shape, max_dims=0), label="key") value = data.draw( xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value" @@ -104,10 +106,100 @@ def test_setitem(shape, data): else: assert res[key] == value, msg else: - ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key]) + ph.assert_0d_equals( + "__setitem__", "value", value, f"modified x[{key}]", res[key] + ) + _key = key if isinstance(key, tuple) else (key,) + assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis + _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) + unaffected_indices = list(sh.ndindex(res.shape)) + unaffected_indices.remove(_key) + for idx in unaffected_indices: + ph.assert_0d_equals( + "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + ) + + +# TODO: make mask tests optional + + +@given(hh.shapes(), st.data()) +def test_getitem_masking(shape, data): + x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") + mask_shapes = st.one_of( + st.sampled_from([x.shape, ()]), + st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map( + lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l)) + ), + hh.shapes(), + ) + key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key") + if key.ndim > x.ndim or not all( + ks in (xs, 0) for xs, ks in zip(x.shape, key.shape) + ): + with pytest.raises(IndexError): + x[key] + return + + out = x[key] -# TODO: test boolean indexing + ph.assert_dtype("__getitem__", x.dtype, out.dtype) + if key.ndim == 0: + out_shape = (1,) if key else (0,) + out_shape += x.shape + else: + size = int(xp.sum(xp.astype(key, xp.uint8))) + out_shape = (size,) + x.shape[key.ndim :] + ph.assert_shape("__getitem__", out.shape, out_shape) + if not any(s == 0 for s in key.shape): + assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios + out_indices = sh.ndindex(out.shape) + for x_idx in sh.ndindex(x.shape): + if key[x_idx]: + out_idx = next(out_indices) + ph.assert_0d_equals( + "__getitem__", + f"x[{x_idx}]", + x[x_idx], + f"out[{out_idx}]", + out[out_idx], + ) + + +@given(hh.shapes(), st.data()) +def test_setitem_masking(shape, data): + x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") + key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") + value = data.draw( + xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value" + ) + + res = xp.asarray(x, copy=True) + res[key] = value + + ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype") + scalar_type = dh.get_scalar_type(x.dtype) + for idx in sh.ndindex(x.shape): + if key[idx]: + if isinstance(value, Scalar): + ph.assert_scalar_equals( + "__setitem__", + scalar_type, + idx, + scalar_type(res[idx]), + value, + repr_name="modified x", + ) + else: + ph.assert_0d_equals( + "__setitem__", "value", value, f"modified x[{idx}]", res[idx] + ) + else: + ph.assert_0d_equals( + "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + ) def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param: