Skip to content

Commit 3226e0f

Browse files
committed
Test resulting elements from __getitem__
1 parent 810c2bb commit 3226e0f

File tree

3 files changed

+68
-150
lines changed

3 files changed

+68
-150
lines changed

xptests/pytest_helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"assert_result_shape",
2424
"assert_keepdimable_shape",
2525
"assert_fill",
26+
"assert_array",
2627
]
2728

2829

@@ -226,3 +227,20 @@ def assert_fill(
226227
assert ah.all(ah.isnan(out)), msg
227228
else:
228229
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
230+
231+
232+
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
233+
assert_dtype(func_name, out.dtype, expected.dtype, **kw)
234+
assert_shape(func_name, out.shape, expected.shape, **kw)
235+
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
236+
if dh.is_float_dtype(out.dtype):
237+
neg_zeros = expected == -0.0
238+
assert xp.all((out == -0.0) == neg_zeros), msg
239+
pos_zeros = expected == +0.0
240+
assert xp.all((out == +0.0) == pos_zeros), msg
241+
nans = xp.isnan(expected)
242+
assert xp.all(xp.isnan(out) == nans), msg
243+
mask = ~(neg_zeros | pos_zeros | nans)
244+
assert xp.all(out[mask] == expected[mask]), msg
245+
else:
246+
assert xp.all(out == expected), msg

xptests/test_array_object.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,75 @@
1+
import math
2+
from itertools import product
3+
from typing import Sequence, Union
4+
15
import pytest
2-
from hypothesis import given
6+
from hypothesis import assume, given, note
37
from hypothesis import strategies as st
48

59
from . import _array_module as xp
610
from . import dtype_helpers as dh
711
from . import hypothesis_helpers as hh
812
from . import pytest_helpers as ph
913
from . import xps
10-
from .typing import DataType, Param, ScalarType
14+
from .typing import DataType, Param, Scalar, ScalarType, Shape
15+
1116

17+
def reshape(
18+
flat_seq: Sequence[Scalar], shape: Shape
19+
) -> Union[Scalar, Sequence[Scalar]]:
20+
"""Reshape a flat sequence"""
21+
if len(shape) == 0:
22+
assert len(flat_seq) == 1 # sanity check
23+
return flat_seq[0]
24+
elif len(shape) == 1:
25+
return flat_seq
26+
size = len(flat_seq)
27+
n = math.prod(shape[1:])
28+
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
1229

13-
@given(hh.shapes(), st.data())
30+
31+
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
1432
def test_getitem(shape, data):
15-
x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label="x")
33+
size = math.prod(shape)
34+
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
35+
obj = data.draw(
36+
st.lists(
37+
xps.from_dtype(dtype),
38+
min_size=size,
39+
max_size=size,
40+
).map(lambda l: reshape(l, shape)),
41+
label="obj",
42+
)
43+
x = xp.asarray(obj, dtype=dtype)
44+
note(f"{x=}")
1645
key = data.draw(xps.indices(shape=shape), label="key")
1746

1847
out = x[key]
1948

20-
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
21-
2249
_key = tuple(key) if isinstance(key, tuple) else (key,)
2350
if Ellipsis in _key:
2451
start_a = _key.index(Ellipsis)
2552
stop_a = start_a + (len(shape) - (len(_key) - 1))
2653
slices = tuple(slice(None, None) for _ in range(start_a, stop_a))
2754
_key = _key[:start_a] + slices + _key[start_a + 1 :]
28-
expected = []
55+
axes_indices = []
2956
for a, i in enumerate(_key):
30-
if isinstance(i, slice):
31-
r = range(shape[a])[i]
32-
expected.append(len(r))
33-
expected = tuple(expected)
34-
ph.assert_shape("__getitem__", out.shape, expected)
35-
36-
# TODO: fold in all remaining concepts from test_indexing.py
57+
if isinstance(i, int):
58+
axes_indices.append([i])
59+
else:
60+
side = shape[a]
61+
indices = range(side)[i]
62+
assume(len(indices) > 0) # TODO: test 0-sided arrays
63+
axes_indices.append(indices)
64+
expected = []
65+
for idx in product(*axes_indices):
66+
val = obj
67+
for i in idx:
68+
val = val[i]
69+
expected.append(val)
70+
expected = reshape(expected, out.shape)
71+
expected = xp.asarray(expected, dtype=dtype)
72+
ph.assert_array("__getitem__", out, expected)
3773

3874

3975
# TODO: test_setitem

xptests/test_indexing.py

Lines changed: 0 additions & 136 deletions
This file was deleted.

0 commit comments

Comments
 (0)