Skip to content

Commit dcb22c0

Browse files
committed
ENH: test take_along_axis
1 parent 31eec9d commit dcb22c0

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,55 @@ def test_take(x, data):
6060
# sanity check
6161
with pytest.raises(StopIteration):
6262
next(out_indices)
63+
64+
65+
66+
@pytest.mark.unvectorized
67+
@pytest.mark.min_version("2024.12")
68+
@given(
69+
x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
70+
data=st.data(),
71+
)
72+
def test_take_along_axis(x, data):
73+
# TODO
74+
# 1. negative axis
75+
# 2. negative indices
76+
# 3. different dtypes for indices
77+
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
78+
len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis")
79+
80+
idx_shape = x.shape[:axis] + (len_axis,) + x.shape[axis+1:]
81+
indices = data.draw(
82+
hh.arrays(
83+
shape=idx_shape,
84+
dtype=dh.default_int,
85+
elements={"min_value": 0, "max_value": x.shape[axis]-1}
86+
),
87+
label="indices"
88+
)
89+
note(f"{indices=} {idx_shape=}")
90+
91+
out = xp.take_along_axis(x, indices, axis=axis)
92+
93+
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
94+
ph.assert_shape(
95+
"take_along_axis",
96+
out_shape=out.shape,
97+
expected=x.shape[:axis] + (len_axis,) + x.shape[axis+1:],
98+
kw=dict(
99+
x=x,
100+
indices=indices,
101+
axis=axis,
102+
),
103+
)
104+
105+
# value test: notation is from `np.take_along_axis` docstring
106+
Ni, Nk = x.shape[:axis], x.shape[axis+1:]
107+
for ii in sh.ndindex(Ni):
108+
for kk in sh.ndindex(Nk):
109+
a_1d = x[ii + (slice(None),) + kk]
110+
i_1d = indices[ii + (slice(None),) + kk]
111+
o_1d = out[ii + (slice(None),) + kk]
112+
for j in range(len_axis):
113+
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
114+

0 commit comments

Comments
 (0)