Skip to content

Mask tests #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,19 @@ def assert_scalar_equals(
out: Scalar,
expected: Scalar,
/,
repr_name: str = "out",
**kw,
):
out_repr = "out" if idx == () else f"out[{idx}]"
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
f_func = f"{func_name}({fmt_kw(kw)})"
if type_ is bool or type_ is int:
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
assert out == expected, msg
elif math.isnan(expected):
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
assert math.isnan(out), msg
else:
msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]"
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg


Expand Down
96 changes: 94 additions & 2 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from .typing import DataType, Param, Scalar, ScalarType, Shape

Expand Down Expand Up @@ -87,6 +88,7 @@ def test_setitem(shape, data):
)
x = xp.asarray(obj, dtype=dtype)
note(f"{x=}")
# TODO: test setting non-0d arrays
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
value = data.draw(
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
Expand All @@ -104,10 +106,100 @@ def test_setitem(shape, data):
else:
assert res[key] == value, msg
else:
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
ph.assert_0d_equals(
"__setitem__", "value", value, f"modified x[{key}]", res[key]
)
_key = key if isinstance(key, tuple) else (key,)
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
unaffected_indices = list(sh.ndindex(res.shape))
unaffected_indices.remove(_key)
for idx in unaffected_indices:
ph.assert_0d_equals(
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
)


# TODO: make mask tests optional


@given(hh.shapes(), st.data())
def test_getitem_masking(shape, data):
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
mask_shapes = st.one_of(
st.sampled_from([x.shape, ()]),
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
),
hh.shapes(),
)
key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key")

if key.ndim > x.ndim or not all(
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
):
with pytest.raises(IndexError):
x[key]
return

out = x[key]

# TODO: test boolean indexing
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
if key.ndim == 0:
out_shape = (1,) if key else (0,)
out_shape += x.shape
else:
size = int(xp.sum(xp.astype(key, xp.uint8)))
out_shape = (size,) + x.shape[key.ndim :]
ph.assert_shape("__getitem__", out.shape, out_shape)
if not any(s == 0 for s in key.shape):
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
out_indices = sh.ndindex(out.shape)
for x_idx in sh.ndindex(x.shape):
if key[x_idx]:
out_idx = next(out_indices)
ph.assert_0d_equals(
"__getitem__",
f"x[{x_idx}]",
x[x_idx],
f"out[{out_idx}]",
out[out_idx],
)


@given(hh.shapes(), st.data())
def test_setitem_masking(shape, data):
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
value = data.draw(
xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value"
)

res = xp.asarray(x, copy=True)
res[key] = value

ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype")
scalar_type = dh.get_scalar_type(x.dtype)
for idx in sh.ndindex(x.shape):
if key[idx]:
if isinstance(value, Scalar):
ph.assert_scalar_equals(
"__setitem__",
scalar_type,
idx,
scalar_type(res[idx]),
value,
repr_name="modified x",
)
else:
ph.assert_0d_equals(
"__setitem__", "value", value, f"modified x[{idx}]", res[idx]
)
else:
ph.assert_0d_equals(
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
)


def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
Expand Down