@@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
311
311
return X [indexer ]
312
312
313
313
314
- def roll (X , / , shift , * , axis = None ):
314
+ def roll (x , / , shift , * , axis = None ):
315
315
"""
316
316
roll(x, shift, axis)
317
317
@@ -343,41 +343,45 @@ def roll(X, /, shift, *, axis=None):
343
343
`device` attributes as `x` and whose elements are shifted relative
344
344
to `x`.
345
345
"""
346
- if not isinstance (X , dpt .usm_ndarray ):
347
- raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
348
- exec_q = X .sycl_queue
346
+ if not isinstance (x , dpt .usm_ndarray ):
347
+ raise TypeError (f"Expected usm_ndarray type, got { type (x )} ." )
348
+ exec_q = x .sycl_queue
349
349
_manager = dputils .SequentialOrderManager [exec_q ]
350
350
if axis is None :
351
351
shift = operator .index (shift )
352
- dep_evs = _manager .submitted_events
353
352
res = dpt .empty (
354
- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
353
+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
355
354
)
355
+ sz = operator .index (x .size )
356
+ shift = (shift % sz ) if sz > 0 else 0
357
+ dep_evs = _manager .submitted_events
356
358
hev , roll_ev = ti ._copy_usm_ndarray_for_roll_1d (
357
- src = X ,
359
+ src = x ,
358
360
dst = res ,
359
361
shift = shift ,
360
362
sycl_queue = exec_q ,
361
363
depends = dep_evs ,
362
364
)
363
365
_manager .add_event_pair (hev , roll_ev )
364
366
return res
365
- axis = normalize_axis_tuple (axis , X .ndim , allow_duplicate = True )
367
+ axis = normalize_axis_tuple (axis , x .ndim , allow_duplicate = True )
366
368
broadcasted = np .broadcast (shift , axis )
367
369
if broadcasted .ndim > 1 :
368
370
raise ValueError ("'shift' and 'axis' should be scalars or 1D sequences" )
369
371
shifts = [
370
372
0 ,
371
- ] * X .ndim
373
+ ] * x .ndim
374
+ shape = x .shape
372
375
for sh , ax in broadcasted :
373
- shifts [ax ] += sh
374
-
376
+ n_i = operator .index (shape [ax ])
377
+ shifted = shifts [ax ] + operator .index (sh )
378
+ shifts [ax ] = (shifted % n_i ) if n_i > 0 else 0
375
379
res = dpt .empty (
376
- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
380
+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
377
381
)
378
382
dep_evs = _manager .submitted_events
379
383
ht_e , roll_ev = ti ._copy_usm_ndarray_for_roll_nd (
380
- src = X , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
384
+ src = x , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
381
385
)
382
386
_manager .add_event_pair (ht_e , roll_ev )
383
387
return res
0 commit comments