Skip to content

Commit 96717db

Browse files
committed
test_take
1 parent a533680 commit 96717db

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

array_api_tests/_array_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65+
_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout
6566
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6667

6768
for attr in _top_level_attrs:
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
from hypothesis import given, note
3+
from hypothesis import strategies as st
4+
5+
from . import _array_module as xp
6+
from . import dtype_helpers as dh
7+
from . import hypothesis_helpers as hh
8+
from . import pytest_helpers as ph
9+
from . import shape_helpers as sh
10+
from . import xps
11+
12+
pytestmark = pytest.mark.ci
13+
14+
15+
@pytest.mark.min_version("2022.12")
16+
@given(
17+
x=xps.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)),
18+
data=st.data(),
19+
)
20+
def test_take(x, data):
21+
# TODO:
22+
# * negative axis
23+
# * negative indices
24+
# * different dtypes for indices
25+
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
26+
_indices = data.draw(
27+
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
28+
label="_indices",
29+
)
30+
indices = xp.asarray(_indices, dtype=dh.default_int)
31+
note(f"{indices=}")
32+
33+
out = xp.take(x, indices, axis=axis)
34+
35+
ph.assert_dtype("take", x.dtype, out.dtype)
36+
ph.assert_shape(
37+
"take",
38+
out.shape,
39+
x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
40+
x=x,
41+
indices=indices,
42+
axis=axis,
43+
)
44+
out_indices = sh.ndindex(out.shape)
45+
axis_indices = list(sh.axis_ndindex(x.shape, axis))
46+
for axis_idx in axis_indices:
47+
f_axis_idx = sh.fmt_idx("x", axis_idx)
48+
for i in _indices:
49+
f_take_idx = sh.fmt_idx(f_axis_idx, i)
50+
indexed_x = x[axis_idx][i]
51+
for at_idx in sh.ndindex(indexed_x.shape):
52+
out_idx = next(out_indices)
53+
ph.assert_0d_equals(
54+
"take",
55+
sh.fmt_idx(f_take_idx, at_idx),
56+
indexed_x[at_idx],
57+
sh.fmt_idx("out", out_idx),
58+
out[out_idx],
59+
)
60+
# sanity check
61+
with pytest.raises(StopIteration):
62+
next(out_indices)

0 commit comments

Comments
 (0)