Skip to content

Commit 9eb8f03

Browse files
Merge pull request #1869 from IntelPython/bugfix/gh-1857-roll-with-large-shift
Roll must reduce the shift modulo the array size along the axis
2 parents 85e4121 + bd0c9b2 commit 9eb8f03

File tree

3 files changed

+42
-14
lines changed

3 files changed

+42
-14
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
311311
return X[indexer]
312312

313313

314-
def roll(X, /, shift, *, axis=None):
314+
def roll(x, /, shift, *, axis=None):
315315
"""
316316
roll(x, shift, axis)
317317
@@ -343,41 +343,45 @@ def roll(X, /, shift, *, axis=None):
343343
`device` attributes as `x` and whose elements are shifted relative
344344
to `x`.
345345
"""
346-
if not isinstance(X, dpt.usm_ndarray):
347-
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
348-
exec_q = X.sycl_queue
346+
if not isinstance(x, dpt.usm_ndarray):
347+
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
348+
exec_q = x.sycl_queue
349349
_manager = dputils.SequentialOrderManager[exec_q]
350350
if axis is None:
351351
shift = operator.index(shift)
352-
dep_evs = _manager.submitted_events
353352
res = dpt.empty(
354-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
353+
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
355354
)
355+
sz = operator.index(x.size)
356+
shift = (shift % sz) if sz > 0 else 0
357+
dep_evs = _manager.submitted_events
356358
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
357-
src=X,
359+
src=x,
358360
dst=res,
359361
shift=shift,
360362
sycl_queue=exec_q,
361363
depends=dep_evs,
362364
)
363365
_manager.add_event_pair(hev, roll_ev)
364366
return res
365-
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
367+
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
366368
broadcasted = np.broadcast(shift, axis)
367369
if broadcasted.ndim > 1:
368370
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
369371
shifts = [
370372
0,
371-
] * X.ndim
373+
] * x.ndim
374+
shape = x.shape
372375
for sh, ax in broadcasted:
373-
shifts[ax] += sh
374-
376+
n_i = operator.index(shape[ax])
377+
shifted = shifts[ax] + operator.index(sh)
378+
shifts[ax] = (shifted % n_i) if n_i > 0 else 0
375379
res = dpt.empty(
376-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
380+
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
377381
)
378382
dep_evs = _manager.submitted_events
379383
ht_e, roll_ev = ti._copy_usm_ndarray_for_roll_nd(
380-
src=X, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
384+
src=x, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
381385
)
382386
_manager.add_event_pair(ht_e, roll_ev)
383387
return res

dpctl/tensor/libtensor/source/copy_for_roll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ copy_usm_ndarray_for_roll_nd(const dpctl::tensor::usm_ndarray &src,
326326
// normalize shift parameter to be 0 <= offset < dim
327327
py::ssize_t dim = src_shape_ptr[i];
328328
size_t offset =
329-
(shifts[i] > 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim);
329+
(shifts[i] >= 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim);
330330

331331
normalized_shifts.push_back(offset);
332332
}

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,30 @@ def test_roll_2d(data):
657657
assert_array_equal(Ynp, dpt.asnumpy(Y))
658658

659659

660+
def test_roll_out_bounds_shifts():
661+
"See gh-1857"
662+
get_queue_or_skip()
663+
664+
x = dpt.arange(4)
665+
y = dpt.roll(x, np.uint64(2**63 + 2))
666+
expected = dpt.roll(x, 2)
667+
assert dpt.all(y == expected)
668+
669+
x_empty = x[1:1]
670+
y = dpt.roll(x_empty, 11)
671+
assert y.size == 0
672+
673+
x_2d = dpt.reshape(x, (2, 2))
674+
y = dpt.roll(x_2d, np.uint64(2**63 + 1), axis=1)
675+
expected = dpt.roll(x_2d, 1, axis=1)
676+
assert dpt.all(y == expected)
677+
678+
x_2d_empty = x_2d[:, 1:1]
679+
y = dpt.roll(x_2d_empty, 3, axis=1)
680+
expected = dpt.empty_like(x_2d_empty)
681+
assert dpt.all(y == expected)
682+
683+
660684
def test_roll_validation():
661685
get_queue_or_skip()
662686

0 commit comments

Comments
 (0)