Skip to content

Commit 5352dd5

Browse files
authored
Merge pull request #1934 from IntelPython/permit-uint64-indices-along-axis-fns
Allow uint64 indices in `dpt.take_along_axis` and `dpt.put_along_axis`
2 parents c5cb665 + 7602b4f commit 5352dd5

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,11 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
491491
"from input arguments. "
492492
)
493493
mode_i = _get_indexing_mode(mode)
494-
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
494+
indexes_dt = (
495+
dpt.uint64
496+
if indices.dtype == dpt.uint64
497+
else ti.default_device_index_type(exec_q.sycl_device)
498+
)
495499
_ind = tuple(
496500
(
497501
indices
@@ -567,7 +571,11 @@ def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
567571
)
568572
out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
569573
mode_i = _get_indexing_mode(mode)
570-
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
574+
indexes_dt = (
575+
dpt.uint64
576+
if indices.dtype == dpt.uint64
577+
else ti.default_device_index_type(exec_q.sycl_device)
578+
)
571579
_ind = tuple(
572580
(
573581
indices

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,3 +1858,35 @@ def test_put_indices_oob_py_ssize_t(mode):
18581858

18591859
assert dpt.all(x[:-1] == -1)
18601860
assert x[-1] == i
1861+
1862+
1863+
def test_take_along_axis_uint64_indices():
1864+
get_queue_or_skip()
1865+
1866+
inds = dpt.arange(1, 10, 2, dtype="u8")
1867+
x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5)
1868+
res = dpt.take_along_axis(x, inds)
1869+
assert dpt.all(res == -1)
1870+
1871+
sh0 = 2
1872+
inds = dpt.broadcast_to(inds, (sh0,) + inds.shape)
1873+
x = dpt.broadcast_to(x, (sh0,) + x.shape)
1874+
res = dpt.take_along_axis(x, inds, axis=1)
1875+
assert dpt.all(res == -1)
1876+
1877+
1878+
def test_put_along_axis_uint64_indices():
1879+
get_queue_or_skip()
1880+
1881+
inds = dpt.arange(1, 10, 2, dtype="u8")
1882+
x = dpt.zeros(10, dtype="i4")
1883+
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype))
1884+
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5)
1885+
assert dpt.all(x == expected)
1886+
1887+
sh0 = 2
1888+
inds = dpt.broadcast_to(inds, (sh0,) + inds.shape)
1889+
x = dpt.zeros((sh0,) + x.shape, dtype="i4")
1890+
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
1891+
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
1892+
assert dpt.all(expected == x)

0 commit comments

Comments
 (0)