Skip to content

Commit 043e437

Browse files
committed
Test __setitem__
1 parent 1c3bb46 commit 043e437

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

xptests/test_array_object.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Sequence, Union
3+
from typing import Sequence, Union, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -33,11 +33,9 @@ def test_getitem(shape, data):
3333
size = math.prod(shape)
3434
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
3535
obj = data.draw(
36-
st.lists(
37-
xps.from_dtype(dtype),
38-
min_size=size,
39-
max_size=size,
40-
).map(lambda l: reshape(l, shape)),
36+
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
37+
lambda l: reshape(l, shape)
38+
),
4139
label="obj",
4240
)
4341
x = xp.asarray(obj, dtype=dtype)
@@ -47,7 +45,6 @@ def test_getitem(shape, data):
4745
out = x[key]
4846

4947
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
50-
5148
_key = tuple(key) if isinstance(key, tuple) else (key,)
5249
if Ellipsis in _key:
5350
start_a = _key.index(Ellipsis)
@@ -78,7 +75,39 @@ def test_getitem(shape, data):
7875
ph.assert_array("__getitem__", out, expected)
7976

8077

81-
# TODO: test_setitem
78+
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
79+
def test_setitem(shape, data):
80+
size = math.prod(shape)
81+
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
82+
obj = data.draw(
83+
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
84+
lambda l: reshape(l, shape)
85+
),
86+
label="obj",
87+
)
88+
x = xp.asarray(obj, dtype=dtype)
89+
note(f"{x=}")
90+
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
91+
value = data.draw(
92+
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
93+
)
94+
95+
res = xp.asarray(x, copy=True)
96+
res[key] = value
97+
98+
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
99+
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
100+
if isinstance(value, get_args(Scalar)):
101+
msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]"
102+
if math.isnan(value):
103+
assert xp.isnan(res[key]), msg
104+
else:
105+
assert res[key] == value, msg
106+
else:
107+
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
108+
109+
110+
# TODO: test boolean indexing
82111

83112

84113
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:

0 commit comments

Comments
 (0)