Skip to content

Commit 810c2bb

Browse files
committed
Rudimentary test_getitem
Does not yet supercede `test_indexing.py`
1 parent c5b3d62 commit 810c2bb

File tree

3 files changed

+63
-48
lines changed

3 files changed

+63
-48
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pytest
2-
hypothesis>=6.30.0
2+
hypothesis>=6.31.1
33
regex
44
removestar

xptests/test_array2scalar.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

xptests/test_array_object.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
from hypothesis import given
3+
from hypothesis import strategies as st
4+
5+
from . import _array_module as xp
6+
from . import dtype_helpers as dh
7+
from . import hypothesis_helpers as hh
8+
from . import pytest_helpers as ph
9+
from . import xps
10+
from .typing import DataType, Param, ScalarType
11+
12+
13+
@given(hh.shapes(), st.data())
14+
def test_getitem(shape, data):
15+
x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label="x")
16+
key = data.draw(xps.indices(shape=shape), label="key")
17+
18+
out = x[key]
19+
20+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
21+
22+
_key = tuple(key) if isinstance(key, tuple) else (key,)
23+
if Ellipsis in _key:
24+
start_a = _key.index(Ellipsis)
25+
stop_a = start_a + (len(shape) - (len(_key) - 1))
26+
slices = tuple(slice(None, None) for _ in range(start_a, stop_a))
27+
_key = _key[:start_a] + slices + _key[start_a + 1 :]
28+
expected = []
29+
for a, i in enumerate(_key):
30+
if isinstance(i, slice):
31+
r = range(shape[a])[i]
32+
expected.append(len(r))
33+
expected = tuple(expected)
34+
ph.assert_shape("__getitem__", out.shape, expected)
35+
36+
# TODO: fold in all remaining concepts from test_indexing.py
37+
38+
39+
# TODO: test_setitem
40+
41+
42+
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
43+
return pytest.param(
44+
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
45+
)
46+
47+
48+
@pytest.mark.parametrize(
49+
"method_name, dtype, stype",
50+
[make_param("__bool__", xp.bool, bool)]
51+
+ [make_param("__int__", d, int) for d in dh.all_int_dtypes]
52+
+ [make_param("__index__", d, int) for d in dh.all_int_dtypes]
53+
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
54+
)
55+
@given(data=st.data())
56+
def test_duck_typing(method_name, dtype, stype, data):
57+
x = data.draw(xps.arrays(dtype, shape=()), label="x")
58+
method = getattr(x, method_name)
59+
out = method()
60+
assert isinstance(
61+
out, stype
62+
), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar"

0 commit comments

Comments
 (0)