Skip to content

Commit b38478b

Browse files
committed
Add another test for advanced indexing
Test for the case where an basic integral index appears between two integral arrays, followed by a basic index, and then followed by `:`
1 parent 72f8938 commit b38478b

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,24 @@ def test_advanced_slice13():
402402
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()
403403

404404

405+
def test_advanced_slice14():
406+
q = get_queue_or_skip()
407+
ii = dpt.asarray([1, 2], sycl_queue=q)
408+
x = dpt.reshape(dpt.arange(3**5, dtype="i4", sycl_queue=q), (3,) * 5)
409+
y = x[ii, 0, ii, 1, :]
410+
assert isinstance(y, dpt.usm_ndarray)
411+
# integers broadcast to ii.shape per array API
412+
assert y.shape == ii.shape + x.shape[-1:]
413+
assert _all_equal(
414+
(
415+
x[ii[i], 0, ii[i], 1, k]
416+
for i in range(ii.shape[0])
417+
for k in range(x.shape[-1])
418+
),
419+
(y[i, k] for i in range(ii.shape[0]) for k in range(x.shape[-1])),
420+
)
421+
422+
405423
def test_boolean_indexing_validation():
406424
get_queue_or_skip()
407425
x = dpt.zeros(10, dtype="i4")

0 commit comments

Comments
 (0)