@@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
311
311
return X [indexer ]
312
312
313
313
314
- def roll (X , / , shift , * , axis = None ):
314
+ def roll (x , / , shift , * , axis = None ):
315
315
"""
316
316
roll(x, shift, axis)
317
317
@@ -343,44 +343,44 @@ def roll(X, /, shift, *, axis=None):
343
343
`device` attributes as `x` and whose elements are shifted relative
344
344
to `x`.
345
345
"""
346
- if not isinstance (X , dpt .usm_ndarray ):
347
- raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
348
- exec_q = X .sycl_queue
346
+ if not isinstance (x , dpt .usm_ndarray ):
347
+ raise TypeError (f"Expected usm_ndarray type, got { type (x )} ." )
348
+ exec_q = x .sycl_queue
349
349
_manager = dputils .SequentialOrderManager [exec_q ]
350
350
if axis is None :
351
351
shift = operator .index (shift )
352
- dep_evs = _manager .submitted_events
353
352
res = dpt .empty (
354
- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
353
+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
355
354
)
356
- sz = X .size
355
+ sz = x .size
357
356
shift = (shift % sz ) if sz > 0 else 0
357
+ dep_evs = _manager .submitted_events
358
358
hev , roll_ev = ti ._copy_usm_ndarray_for_roll_1d (
359
- src = X ,
359
+ src = x ,
360
360
dst = res ,
361
361
shift = shift ,
362
362
sycl_queue = exec_q ,
363
363
depends = dep_evs ,
364
364
)
365
365
_manager .add_event_pair (hev , roll_ev )
366
366
return res
367
- axis = normalize_axis_tuple (axis , X .ndim , allow_duplicate = True )
367
+ axis = normalize_axis_tuple (axis , x .ndim , allow_duplicate = True )
368
368
broadcasted = np .broadcast (shift , axis )
369
369
if broadcasted .ndim > 1 :
370
370
raise ValueError ("'shift' and 'axis' should be scalars or 1D sequences" )
371
371
shifts = [
372
372
0 ,
373
- ] * X .ndim
374
- shape = X .shape
373
+ ] * x .ndim
374
+ shape = x .shape
375
375
for sh , ax in broadcasted :
376
376
n_i = shape [ax ]
377
- shifts [ax ] = int (shifts [ax ] + sh ) % int (n_i ) if n_i > 0 else 0
377
+ shifts [ax ] = ( int (shifts [ax ] + sh ) % int (n_i ) ) if n_i > 0 else 0
378
378
res = dpt .empty (
379
- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
379
+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
380
380
)
381
381
dep_evs = _manager .submitted_events
382
382
ht_e , roll_ev = ti ._copy_usm_ndarray_for_roll_nd (
383
- src = X , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
383
+ src = x , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
384
384
)
385
385
_manager .add_event_pair (ht_e , roll_ev )
386
386
return res
0 commit comments