Skip to content

Commit bd0c9b2

Browse files
Use operator.index to normalize shift
1 parent 23fd986 commit bd0c9b2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def roll(x, /, shift, *, axis=None):
352352
res = dpt.empty(
353353
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
354354
)
355-
sz = x.size
355+
sz = operator.index(x.size)
356356
shift = (shift % sz) if sz > 0 else 0
357357
dep_evs = _manager.submitted_events
358358
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
@@ -373,8 +373,9 @@ def roll(x, /, shift, *, axis=None):
373373
] * x.ndim
374374
shape = x.shape
375375
for sh, ax in broadcasted:
376-
n_i = shape[ax]
377-
shifts[ax] = (int(shifts[ax] + sh) % int(n_i)) if n_i > 0 else 0
376+
n_i = operator.index(shape[ax])
377+
shifted = shifts[ax] + operator.index(sh)
378+
shifts[ax] = (shifted % n_i) if n_i > 0 else 0
378379
res = dpt.empty(
379380
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
380381
)

0 commit comments

Comments
 (0)