Skip to content

Commit e83b101

Browse files
committed
Test unaffected indices of test_setitem
1 parent fe9e122 commit e83b101

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

array_api_tests/test_array_object.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_setitem(shape, data):
8888
)
8989
x = xp.asarray(obj, dtype=dtype)
9090
note(f"{x=}")
91+
# TODO: test setting non-0d arrays
9192
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
9293
value = data.draw(
9394
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
@@ -105,7 +106,18 @@ def test_setitem(shape, data):
105106
else:
106107
assert res[key] == value, msg
107108
else:
108-
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
109+
ph.assert_0d_equals(
110+
"__setitem__", "value", value, f"modified x[{key}]", res[key]
111+
)
112+
_key = key if isinstance(key, tuple) else (key,)
113+
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
114+
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
115+
unaffected_indices = list(sh.ndindex(res.shape))
116+
unaffected_indices.remove(_key)
117+
for idx in unaffected_indices:
118+
ph.assert_0d_equals(
119+
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
120+
)
109121

110122

111123
# TODO: make mask tests optional
@@ -140,7 +152,6 @@ def test_getitem_mask(shape, data):
140152
size = int(xp.sum(xp.astype(key, xp.uint8)))
141153
out_shape = (size,) + x.shape[key.ndim :]
142154
ph.assert_shape("__getitem__", out.shape, out_shape)
143-
144155
if not any(s == 0 for s in key.shape):
145156
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
146157
out_indices = sh.ndindex(out.shape)
@@ -169,7 +180,6 @@ def test_setitem_mask(shape, data):
169180

170181
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
171182
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype")
172-
173183
scalar_type = dh.get_scalar_type(x.dtype)
174184
for idx in sh.ndindex(x.shape):
175185
if key[idx]:

0 commit comments

Comments
 (0)