Skip to content

Commit 0784d70

Browse files
committed
Partial reversion of advanced indexing changes
0D array indices are treated as Python scalars and moved to the host when not part of other, consecutive array indices This means that 0D arrays also can't start a series of consecutive advanced indices
1 parent d873f35 commit 0784d70

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ cdef Py_ssize_t _slice_len(
4646
cdef bint _is_integral(object x) except *:
4747
"""Gives True if x is an integral slice spec"""
4848
if isinstance(x, usm_ndarray):
49-
return False
49+
if x.ndim > 0:
50+
return False
51+
if x.dtype.kind not in "ui":
52+
return False
53+
return True
5054
if isinstance(x, bool):
5155
return False
5256
if isinstance(x, int):
@@ -194,7 +198,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
194198
dt_k = i.dtype.kind
195199
if dt_k == "b" and i.ndim > 0:
196200
axes_referenced += i.ndim
197-
elif dt_k in "ui":
201+
elif dt_k in "ui" and i.ndim > 0:
198202
axes_referenced += 1
199203
else:
200204
raise IndexError(
@@ -266,19 +270,21 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
266270
if array_streak:
267271
array_streak = False
268272
elif _is_integral(ind_i):
269-
ind_i = ind_i.__index__()
270273
if array_streak:
271-
# integer will be converted to an array, still raise if OOB
272-
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
273-
raise IndexError(
274-
("Index {0} is out of range for "
275-
"axes {1} with size {2}").format(ind_i, k, shape[k]))
274+
if not isinstance(ind_i, usm_ndarray):
275+
ind_i = ind_i.__index__()
276+
# integer will be converted to an array, still raise if OOB
277+
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
278+
raise IndexError(
279+
("Index {0} is out of range for "
280+
"axes {1} with size {2}").format(ind_i, k, shape[k]))
276281
new_advanced_ind.append(ind_i)
277282
k_new = k + 1
278283
new_shape.extend(shape[k:k_new])
279284
new_strides.extend(strides[k:k_new])
280285
k = k_new
281286
else:
287+
ind_i = ind_i.__index__()
282288
if 0 <= ind_i < shape[k]:
283289
k_new = k + 1
284290
if not is_empty:

0 commit comments

Comments
 (0)