Skip to content

Commit a077f42

Browse files
Merge pull request #1792 from IntelPython/fix-for-gh-1785
Fix for gh 1785
2 parents dea8b83 + 7f9ab84 commit a077f42

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ cdef bint _is_boolean(object x) except *:
8181
return True
8282
if isinstance(x, bool):
8383
return True
84-
if isinstance(x, int):
84+
if isinstance(x, (int, float, complex)):
8585
return False
8686
if _is_buffer(x):
8787
mbuf = memoryview(x)
@@ -204,7 +204,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
204204
)
205205
array_count += 1
206206
else:
207-
raise TypeError
207+
raise IndexError(
208+
"Only integers, slices (`:`), ellipsis (`...`), "
209+
"dpctl.tensor.newaxis (`None`) and integer and "
210+
"boolean arrays are valid indices."
211+
)
208212
if ellipses_count > 1:
209213
raise IndexError(
210214
"an index can only have a single ellipsis ('...')")
@@ -283,6 +287,8 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
283287
new_shape.extend(shape[k:k_new])
284288
new_strides.extend(strides[k:k_new])
285289
k = k_new
290+
else:
291+
raise IndexError
286292
new_shape.extend(shape[k:])
287293
new_strides.extend(strides[k:])
288294
new_shape_len += len(shape) - k
@@ -291,4 +297,8 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
291297
# assert len(new_advanced_ind) == array_count
292298
return (tuple(new_shape), tuple(new_strides), new_offset, tuple(new_advanced_ind), new_advanced_start_pos)
293299
else:
294-
raise TypeError
300+
raise IndexError(
301+
"Only integers, slices (`:`), ellipsis (`...`), "
302+
"dpctl.tensor.newaxis (`None`) and integer and "
303+
"boolean arrays are valid indices."
304+
)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def test_slicing_basic():
441441
Xusm[:, -4]
442442
with pytest.raises(IndexError):
443443
Xusm[:, -128]
444-
with pytest.raises(TypeError):
444+
with pytest.raises(IndexError):
445445
Xusm[{1, 2, 3, 4, 5, 6, 7}]
446446
X = dpt.usm_ndarray(10, "u1")
447447
X.usm_data.copy_from_host(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09")
@@ -2416,3 +2416,25 @@ def test_asarray_writable_flag(ro_flag):
24162416

24172417
assert b.flags["W"] == (not ro_flag)
24182418
assert b._pointer == a._pointer
2419+
2420+
2421+
def test_getitem_validation():
2422+
"""Test based on gh-1785"""
2423+
try:
2424+
a = dpt.empty((2, 2, 2))
2425+
except dpctl.SyclDeviceCreationError:
2426+
pytest.skip("No SYCL devices available")
2427+
with pytest.raises(IndexError):
2428+
a[0.0]
2429+
with pytest.raises(IndexError):
2430+
a[1, 0.0, ...]
2431+
with pytest.raises(IndexError):
2432+
a[1, 0.0, dpt.newaxis, 1]
2433+
with pytest.raises(IndexError):
2434+
a[dpt.newaxis, ..., 0.0]
2435+
with pytest.raises(IndexError):
2436+
a[dpt.newaxis, ..., 0.0, dpt.newaxis]
2437+
with pytest.raises(IndexError):
2438+
a[..., 0.0, dpt.newaxis]
2439+
with pytest.raises(IndexError):
2440+
a[:, 0.0, dpt.newaxis]

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ def test_take_arg_validation():
983983
dpt.take(dict(), ind0, axis=0)
984984
with pytest.raises(TypeError):
985985
dpt.take(x, dict(), axis=0)
986-
with pytest.raises(TypeError):
986+
with pytest.raises(IndexError):
987987
x[[]]
988988
with pytest.raises(IndexError):
989989
dpt.take(x, ind1, axis=0)
@@ -1016,7 +1016,7 @@ def test_put_arg_validation():
10161016
dpt.put(dict(), ind0, val, axis=0)
10171017
with pytest.raises(TypeError):
10181018
dpt.put(x, dict(), val, axis=0)
1019-
with pytest.raises(TypeError):
1019+
with pytest.raises(IndexError):
10201020
x[[]] = val
10211021
with pytest.raises(IndexError):
10221022
dpt.put(x, ind1, val, axis=0)

0 commit comments

Comments
 (0)