Skip to content

Commit f50813e

Browse files
committed
Allow uint64 indices in take_along_axis and put_along_axis
1 parent 0bcd635 commit f50813e

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,11 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
491491
"from input arguments. "
492492
)
493493
mode_i = _get_indexing_mode(mode)
494-
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
494+
indexes_dt = (
495+
dpt.uint64
496+
if indices.dtype == dpt.uint64
497+
else ti.default_device_index_type(exec_q.sycl_device)
498+
)
495499
_ind = tuple(
496500
(
497501
indices
@@ -567,7 +571,11 @@ def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
567571
)
568572
out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
569573
mode_i = _get_indexing_mode(mode)
570-
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
574+
indexes_dt = (
575+
dpt.uint64
576+
if indices.dtype == dpt.uint64
577+
else ti.default_device_index_type(exec_q.sycl_device)
578+
)
571579
_ind = tuple(
572580
(
573581
indices

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,3 +1858,36 @@ def test_put_indices_oob_py_ssize_t(mode):
18581858

18591859
assert dpt.all(x[:-1] == -1)
18601860
assert x[-1] == i
1861+
1862+
1863+
def test_take_along_axis_uint64_indices():
1864+
get_queue_or_skip()
1865+
1866+
inds = dpt.arange(1, 10, 2, dtype="u8")
1867+
1868+
x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5)
1869+
res = dpt.take_along_axis(x, inds)
1870+
assert dpt.all(res == -1)
1871+
1872+
x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), (2, 5))
1873+
inds = dpt.arange(1, 10, 2, dtype="u8")
1874+
inds = dpt.broadcast_to(inds, (2, 5))
1875+
res = dpt.take_along_axis(x, inds, axis=1)
1876+
assert dpt.all(res == -1)
1877+
1878+
1879+
def test_put_along_axis_uint64_indices():
1880+
get_queue_or_skip()
1881+
1882+
inds = dpt.arange(1, 10, 2, dtype="u8")
1883+
1884+
x = dpt.zeros(10, dtype="i4")
1885+
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype))
1886+
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5)
1887+
assert dpt.all(x == expected)
1888+
1889+
x = dpt.zeros((2, 10), dtype="i4")
1890+
inds = dpt.broadcast_to(inds, (2, 5))
1891+
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
1892+
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
1893+
assert dpt.all(expected == x)

0 commit comments

Comments
 (0)