diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 74e03dc95e..43162814fc 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -491,7 +491,11 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): "from input arguments. " ) mode_i = _get_indexing_mode(mode) - indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + indexes_dt = ( + dpt.uint64 + if indices.dtype == dpt.uint64 + else ti.default_device_index_type(exec_q.sycl_device) + ) _ind = tuple( ( indices @@ -567,7 +571,11 @@ def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"): ) out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) mode_i = _get_indexing_mode(mode) - indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + indexes_dt = ( + dpt.uint64 + if indices.dtype == dpt.uint64 + else ti.default_device_index_type(exec_q.sycl_device) + ) _ind = tuple( ( indices diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 05b0b278fc..d1226b9bb7 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1858,3 +1858,35 @@ def test_put_indices_oob_py_ssize_t(mode): assert dpt.all(x[:-1] == -1) assert x[-1] == i + + +def test_take_along_axis_uint64_indices(): + get_queue_or_skip() + + inds = dpt.arange(1, 10, 2, dtype="u8") + x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5) + res = dpt.take_along_axis(x, inds) + assert dpt.all(res == -1) + + sh0 = 2 + inds = dpt.broadcast_to(inds, (sh0,) + inds.shape) + x = dpt.broadcast_to(x, (sh0,) + x.shape) + res = dpt.take_along_axis(x, inds, axis=1) + assert dpt.all(res == -1) + + +def test_put_along_axis_uint64_indices(): + get_queue_or_skip() + + inds = dpt.arange(1, 10, 2, dtype="u8") + x = dpt.zeros(10, dtype="i4") + dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype)) + expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5) + assert dpt.all(x == expected) + + sh0 = 2 + inds = dpt.broadcast_to(inds, (sh0,) + inds.shape) + x = dpt.zeros((sh0,) + x.shape, dtype="i4") + dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1) + expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5)) + assert dpt.all(expected == x)