Skip to content

Commit 6abc0d6

Browse files
committed
add support for xp.take
Original NumPy Commit: f07d55b27671a4575e3b9b2fc7ca9ec897d4db9e
1 parent 619fbb5 commit 6abc0d6

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

array_api_strict/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,10 @@
333333
"trunc",
334334
]
335335

336+
from ._indexing_functions import take
337+
338+
__all__ += ["take"]
339+
336340
# linalg is an extension in the array API spec, which is a sub-namespace. Only
337341
# a subset of functions in it are imported into the top-level namespace.
338342
from . import linalg
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import annotations
2+
3+
from ._array_object import Array
4+
from ._dtypes import _integer_dtypes
5+
6+
import numpy as np
7+
8+
def take(x: Array, indices: Array, /, *, axis: int) -> Array:
9+
"""
10+
Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
11+
12+
See its docstring for more information.
13+
"""
14+
if indices.dtype not in _integer_dtypes:
15+
raise TypeError("Only integer dtypes are allowed in indexing")
16+
if indices.ndim != 1:
17+
raise ValueError("Only 1-dim indices array is supported")
18+
return Array._new(np.take(x._array, indices._array, axis=axis))
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
3+
from numpy import array_api as xp
4+
5+
6+
@pytest.mark.parametrize(
7+
"x, indices, axis, expected",
8+
[
9+
([2, 3], [1, 1, 0], 0, [3, 3, 2]),
10+
([2, 3], [1, 1, 0], -1, [3, 3, 2]),
11+
([[2, 3]], [1], -1, [[3]]),
12+
([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]),
13+
],
14+
)
15+
def test_stable_desc_argsort(x, indices, axis, expected):
16+
"""
17+
Indices respect relative order of a descending stable-sort
18+
19+
See https://github.com/numpy/numpy/issues/20778
20+
"""
21+
x = xp.asarray(x)
22+
indices = xp.asarray(indices)
23+
out = xp.take(x, indices, axis=axis)
24+
assert xp.all(out == xp.asarray(expected))

0 commit comments

Comments
 (0)