Skip to content

Commit fe9e122

Browse files
committed
Cover much more in test_setitem_mask
1 parent aea8097 commit fe9e122

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,19 @@ def assert_scalar_equals(
204204
out: Scalar,
205205
expected: Scalar,
206206
/,
207+
repr_name: str = "out",
207208
**kw,
208209
):
209-
out_repr = "out" if idx == () else f"out[{idx}]"
210+
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
210211
f_func = f"{func_name}({fmt_kw(kw)})"
211212
if type_ is bool or type_ is int:
212-
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
213+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
213214
assert out == expected, msg
214215
elif math.isnan(expected):
215-
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
216+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
216217
assert math.isnan(out), msg
217218
else:
218-
msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]"
219+
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
219220
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
220221

221222

array_api_tests/test_array_object.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,40 @@ def test_getitem_mask(shape, data):
156156
)
157157

158158

159-
@given(hh.shapes(min_side=1), st.data())
159+
@given(hh.shapes(), st.data())
160160
def test_setitem_mask(shape, data):
161161
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
162162
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
163-
value = data.draw(xps.from_dtype(x.dtype), label="value") # TODO: more values
163+
value = data.draw(
164+
xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value"
165+
)
164166

165167
res = xp.asarray(x, copy=True)
166168
res[key] = value
167169

168-
# TODO
170+
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
171+
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype")
172+
173+
scalar_type = dh.get_scalar_type(x.dtype)
174+
for idx in sh.ndindex(x.shape):
175+
if key[idx]:
176+
if isinstance(value, Scalar):
177+
ph.assert_scalar_equals(
178+
"__setitem__",
179+
scalar_type,
180+
idx,
181+
scalar_type(res[idx]),
182+
value,
183+
repr_name="modified x",
184+
)
185+
else:
186+
ph.assert_0d_equals(
187+
"__setitem__", "value", value, f"modified x[{idx}]", res[idx]
188+
)
189+
else:
190+
ph.assert_0d_equals(
191+
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
192+
)
169193

170194

171195
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:

0 commit comments

Comments
 (0)