Skip to content

Commit ca3ef30

Browse files
authored
Merge pull request #77 from honno/mask-tests
Mask tests
2 parents 081d700 + be70f26 commit ca3ef30

File tree

2 files changed

+99
-6
lines changed

2 files changed

+99
-6
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,19 @@ def assert_scalar_equals(
204204
out: Scalar,
205205
expected: Scalar,
206206
/,
207+
repr_name: str = "out",
207208
**kw,
208209
):
209-
out_repr = "out" if idx == () else f"out[{idx}]"
210+
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
210211
f_func = f"{func_name}({fmt_kw(kw)})"
211212
if type_ is bool or type_ is int:
212-
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
213+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
213214
assert out == expected, msg
214215
elif math.isnan(expected):
215-
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
216+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
216217
assert math.isnan(out), msg
217218
else:
218-
msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]"
219+
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
219220
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
220221

221222

array_api_tests/test_array_object.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
13+
from . import shape_helpers as sh
1314
from . import xps
1415
from .typing import DataType, Param, Scalar, ScalarType, Shape
1516

@@ -87,6 +88,7 @@ def test_setitem(shape, data):
8788
)
8889
x = xp.asarray(obj, dtype=dtype)
8990
note(f"{x=}")
91+
# TODO: test setting non-0d arrays
9092
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
9193
value = data.draw(
9294
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
@@ -104,10 +106,100 @@ def test_setitem(shape, data):
104106
else:
105107
assert res[key] == value, msg
106108
else:
107-
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
109+
ph.assert_0d_equals(
110+
"__setitem__", "value", value, f"modified x[{key}]", res[key]
111+
)
112+
_key = key if isinstance(key, tuple) else (key,)
113+
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
114+
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
115+
unaffected_indices = list(sh.ndindex(res.shape))
116+
unaffected_indices.remove(_key)
117+
for idx in unaffected_indices:
118+
ph.assert_0d_equals(
119+
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
120+
)
121+
122+
123+
# TODO: make mask tests optional
124+
125+
126+
@given(hh.shapes(), st.data())
127+
def test_getitem_masking(shape, data):
128+
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
129+
mask_shapes = st.one_of(
130+
st.sampled_from([x.shape, ()]),
131+
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
132+
lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
133+
),
134+
hh.shapes(),
135+
)
136+
key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
108137

138+
if key.ndim > x.ndim or not all(
139+
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
140+
):
141+
with pytest.raises(IndexError):
142+
x[key]
143+
return
144+
145+
out = x[key]
109146

110-
# TODO: test boolean indexing
147+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
148+
if key.ndim == 0:
149+
out_shape = (1,) if key else (0,)
150+
out_shape += x.shape
151+
else:
152+
size = int(xp.sum(xp.astype(key, xp.uint8)))
153+
out_shape = (size,) + x.shape[key.ndim :]
154+
ph.assert_shape("__getitem__", out.shape, out_shape)
155+
if not any(s == 0 for s in key.shape):
156+
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
157+
out_indices = sh.ndindex(out.shape)
158+
for x_idx in sh.ndindex(x.shape):
159+
if key[x_idx]:
160+
out_idx = next(out_indices)
161+
ph.assert_0d_equals(
162+
"__getitem__",
163+
f"x[{x_idx}]",
164+
x[x_idx],
165+
f"out[{out_idx}]",
166+
out[out_idx],
167+
)
168+
169+
170+
@given(hh.shapes(), st.data())
171+
def test_setitem_masking(shape, data):
172+
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
173+
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
174+
value = data.draw(
175+
xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value"
176+
)
177+
178+
res = xp.asarray(x, copy=True)
179+
res[key] = value
180+
181+
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
182+
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype")
183+
scalar_type = dh.get_scalar_type(x.dtype)
184+
for idx in sh.ndindex(x.shape):
185+
if key[idx]:
186+
if isinstance(value, Scalar):
187+
ph.assert_scalar_equals(
188+
"__setitem__",
189+
scalar_type,
190+
idx,
191+
scalar_type(res[idx]),
192+
value,
193+
repr_name="modified x",
194+
)
195+
else:
196+
ph.assert_0d_equals(
197+
"__setitem__", "value", value, f"modified x[{idx}]", res[idx]
198+
)
199+
else:
200+
ph.assert_0d_equals(
201+
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
202+
)
111203

112204

113205
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:

0 commit comments

Comments
 (0)