From f50813e269f8d0b84140dd37825d1fc77f444863 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 12 Dec 2024 20:38:23 -0800 Subject: [PATCH 1/2] Allow uint64 indices in take_along_axis and put_along_axis --- dpctl/tensor/_indexing_functions.py | 12 +++++++-- dpctl/tests/test_usm_ndarray_indexing.py | 33 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 74e03dc95e..43162814fc 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -491,7 +491,11 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): "from input arguments. " ) mode_i = _get_indexing_mode(mode) - indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + indexes_dt = ( + dpt.uint64 + if indices.dtype == dpt.uint64 + else ti.default_device_index_type(exec_q.sycl_device) + ) _ind = tuple( ( indices @@ -567,7 +571,11 @@ def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"): ) out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) mode_i = _get_indexing_mode(mode) - indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + indexes_dt = ( + dpt.uint64 + if indices.dtype == dpt.uint64 + else ti.default_device_index_type(exec_q.sycl_device) + ) _ind = tuple( ( indices diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 05b0b278fc..3301a40c2c 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1858,3 +1858,36 @@ def test_put_indices_oob_py_ssize_t(mode): assert dpt.all(x[:-1] == -1) assert x[-1] == i + + +def test_take_along_axis_uint64_indices(): + get_queue_or_skip() + + inds = dpt.arange(1, 10, 2, dtype="u8") + + x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5) + res = dpt.take_along_axis(x, inds) + assert dpt.all(res == -1) + + x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), (2, 5)) + inds = dpt.arange(1, 10, 2, dtype="u8") + inds = dpt.broadcast_to(inds, (2, 5)) + res = dpt.take_along_axis(x, inds, axis=1) + assert dpt.all(res == -1) + + +def test_put_along_axis_uint64_indices(): + get_queue_or_skip() + + inds = dpt.arange(1, 10, 2, dtype="u8") + + x = dpt.zeros(10, dtype="i4") + dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype)) + expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5) + assert dpt.all(x == expected) + + x = dpt.zeros((2, 10), dtype="i4") + inds = dpt.broadcast_to(inds, (2, 5)) + dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1) + expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5)) + assert dpt.all(expected == x) From 7602b4fec6dc5ed91e8bf64996099832c107d7e8 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 13 Dec 2024 09:20:24 -0800 Subject: [PATCH 2/2] Changes per PR review Make tests more maintenance-friendly --- dpctl/tests/test_usm_ndarray_indexing.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 3301a40c2c..d1226b9bb7 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1864,14 +1864,13 @@ def test_take_along_axis_uint64_indices(): get_queue_or_skip() inds = dpt.arange(1, 10, 2, dtype="u8") - x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5) res = dpt.take_along_axis(x, inds) assert dpt.all(res == -1) - x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), (2, 5)) - inds = dpt.arange(1, 10, 2, dtype="u8") - inds = dpt.broadcast_to(inds, (2, 5)) + sh0 = 2 + inds = dpt.broadcast_to(inds, (sh0,) + inds.shape) + x = dpt.broadcast_to(x, (sh0,) + x.shape) res = dpt.take_along_axis(x, inds, axis=1) assert dpt.all(res == -1) @@ -1880,14 +1879,14 @@ def test_put_along_axis_uint64_indices(): get_queue_or_skip() inds = dpt.arange(1, 10, 2, dtype="u8") - x = dpt.zeros(10, dtype="i4") dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype)) expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5) assert dpt.all(x == expected) - x = dpt.zeros((2, 10), dtype="i4") - inds = dpt.broadcast_to(inds, (2, 5)) + sh0 = 2 + inds = dpt.broadcast_to(inds, (sh0,) + inds.shape) + x = dpt.zeros((sh0,) + x.shape, dtype="i4") dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1) expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5)) assert dpt.all(expected == x)