Skip to content

Commit aea8097

Browse files
committed
Rudimentary out elements assertions for getitem masking
1 parent 9ee8381 commit aea8097

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

array_api_tests/test_array_object.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
13+
from . import shape_helpers as sh
1314
from . import xps
1415
from .typing import DataType, Param, Scalar, ScalarType, Shape
1516

@@ -140,12 +141,26 @@ def test_getitem_mask(shape, data):
140141
out_shape = (size,) + x.shape[key.ndim :]
141142
ph.assert_shape("__getitem__", out.shape, out_shape)
142143

144+
if not any(s == 0 for s in key.shape):
145+
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
146+
out_indices = sh.ndindex(out.shape)
147+
for x_idx in sh.ndindex(x.shape):
148+
if key[x_idx]:
149+
out_idx = next(out_indices)
150+
ph.assert_0d_equals(
151+
"__getitem__",
152+
f"x[{x_idx}]",
153+
x[x_idx],
154+
f"out[{out_idx}]",
155+
out[out_idx],
156+
)
157+
143158

144159
@given(hh.shapes(min_side=1), st.data())
145160
def test_setitem_mask(shape, data):
146161
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
147162
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
148-
value = data.draw(xps.from_dtype(x.dtype), label="value")
163+
value = data.draw(xps.from_dtype(x.dtype), label="value") # TODO: more values
149164

150165
res = xp.asarray(x, copy=True)
151166
res[key] = value

0 commit comments

Comments
 (0)