Skip to content

Update integer advanced indexing for array API 2024.12 spec #2032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 85 additions & 48 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import operator
from numbers import Integral

import numpy as np

Expand Down Expand Up @@ -799,6 +800,79 @@ def _nonzero_impl(ary):
return res


def _validate_indices(inds, queue_list, usm_type_list):
"""
Utility for validating indices are usm_ndarray of integral dtype or Python
integers. At least one must be an array.

For each array, the queue and usm type are appended to `queue_list` and
`usm_type_list`, respectively.
"""
any_usmarray = False
for ind in inds:
if isinstance(ind, dpt.usm_ndarray):
any_usmarray = True
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) "
"type"
)
queue_list.append(ind.sycl_queue)
usm_type_list.append(ind.usm_type)
elif not isinstance(ind, Integral):
raise TypeError(
"all elements of `ind` expected to be usm_ndarrays "
f"or integers, found {type(ind)}"
)
if not any_usmarray:
raise TypeError(
"at least one element of `inds` expected to be a usm_ndarray"
)
return inds


def _prepare_indices_arrays(inds, q, usm_type):
"""
Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
a queue (assumed to be common to arrays in inds), and a usm type.

Python scalar integers are promoted to arrays on the provided queue and
with the provided usm type. All arrays are then promoted to a common
integral type (if possible) before being broadcast to a common shape.
"""
# scalar integers -> arrays
inds = tuple(
map(
lambda ind: (
ind
if isinstance(ind, dpt.usm_ndarray)
else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
),
inds,
)
)

# promote to a common integral type if possible
ind_dt = dpt.result_type(*inds)
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: (
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
),
inds,
)
)

# broadcast
inds = dpt.broadcast_arrays(*inds)

return inds


def _take_multi_index(ary, inds, p, mode=0):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError(
Expand All @@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)

_validate_indices(inds, queues_, usm_types_)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
Expand All @@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)

if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: (
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
Expand Down Expand Up @@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)

_validate_indices(inds, queues_, usm_types_)

vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
Expand All @@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)

if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: (
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
Expand Down
64 changes: 45 additions & 19 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import numbers
from operator import index
from cpython.buffer cimport PyObject_CheckBuffer


Expand Down Expand Up @@ -64,7 +65,7 @@ cdef bint _is_integral(object x) except *:
return False
if callable(getattr(x, "__index__", None)):
try:
x.__index__()
index(x)
except (TypeError, ValueError):
return False
return True
Expand Down Expand Up @@ -136,7 +137,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
else:
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
elif _is_integral(ind):
ind = ind.__index__()
ind = index(ind)
new_shape = shape[1:]
new_strides = strides[1:]
is_empty = any(sh_i == 0 for sh_i in new_shape)
Expand Down Expand Up @@ -179,10 +180,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
if array_streak_started:
array_streak_interrupted = True
elif _is_integral(i):
explicit_index += 1
axes_referenced += 1
if array_streak_started:
array_streak_interrupted = True
if array_streak_started and not array_streak_interrupted:
# integers converted to arrays in this case
array_count += 1
else:
explicit_index += 1
elif isinstance(i, usm_ndarray):
if not seen_arrays_yet:
seen_arrays_yet = True
Expand Down Expand Up @@ -229,6 +232,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
advanced_start_pos_set = False
new_offset = offset
is_empty = False
array_streak = False
for i in range(len(ind)):
ind_i = ind[i]
if (ind_i is Ellipsis):
Expand All @@ -239,9 +243,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif ind_i is None:
new_shape.append(1)
new_strides.append(0)
if array_streak:
array_streak = False
elif isinstance(ind_i, slice):
k_new = k + 1
sl_start, sl_stop, sl_step = ind_i.indices(shape[k])
Expand All @@ -255,26 +263,46 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif _is_boolean(ind_i):
new_shape.append(1 if ind_i else 0)
new_strides.append(0)
if array_streak:
array_streak = False
elif _is_integral(ind_i):
ind_i = ind_i.__index__()
if 0 <= ind_i < shape[k]:
if array_streak:
if not isinstance(ind_i, usm_ndarray):
ind_i = index(ind_i)
# integer will be converted to an array, still raise if OOB
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
new_advanced_ind.append(ind_i)
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
ind_i = index(ind_i)
if 0 <= ind_i < shape[k]:
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
elif isinstance(ind_i, usm_ndarray):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
new_advanced_start_pos = len(new_shape)
advanced_start_pos_set = True
Expand All @@ -287,8 +315,6 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
k = k_new
else:
raise IndexError
new_shape.extend(shape[k:])
new_strides.extend(strides[k:])
new_shape_len += len(shape) - k
Expand Down
11 changes: 6 additions & 5 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue)
ev = self_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
Expand Down Expand Up @@ -962,6 +961,8 @@ cdef class usm_ndarray:
return res

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index

# if len(adv_ind == 1), the (only) element is always an array
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
key_ = adv_ind[0]
adv_ind_end_p = key_.ndim + adv_ind_start_p
Expand All @@ -979,10 +980,10 @@ cdef class usm_ndarray:
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
Expand Down Expand Up @@ -1433,10 +1434,10 @@ cdef class usm_ndarray:
_place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p)
return

if any(ind.dtype == dpt_bool for ind in adv_ind):
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
Expand Down
Loading
Loading