Skip to content

Commit 10f5706

Browse files
Renamed X to x as per docstring, add parens to inline if
1 parent 275fdba commit 10f5706

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
311311
return X[indexer]
312312

313313

314-
def roll(X, /, shift, *, axis=None):
314+
def roll(x, /, shift, *, axis=None):
315315
"""
316316
roll(x, shift, axis)
317317
@@ -343,44 +343,44 @@ def roll(X, /, shift, *, axis=None):
343343
`device` attributes as `x` and whose elements are shifted relative
344344
to `x`.
345345
"""
346-
if not isinstance(X, dpt.usm_ndarray):
347-
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
348-
exec_q = X.sycl_queue
346+
if not isinstance(x, dpt.usm_ndarray):
347+
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
348+
exec_q = x.sycl_queue
349349
_manager = dputils.SequentialOrderManager[exec_q]
350350
if axis is None:
351351
shift = operator.index(shift)
352-
dep_evs = _manager.submitted_events
353352
res = dpt.empty(
354-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
353+
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
355354
)
356-
sz = X.size
355+
sz = x.size
357356
shift = (shift % sz) if sz > 0 else 0
357+
dep_evs = _manager.submitted_events
358358
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
359-
src=X,
359+
src=x,
360360
dst=res,
361361
shift=shift,
362362
sycl_queue=exec_q,
363363
depends=dep_evs,
364364
)
365365
_manager.add_event_pair(hev, roll_ev)
366366
return res
367-
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
367+
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
368368
broadcasted = np.broadcast(shift, axis)
369369
if broadcasted.ndim > 1:
370370
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
371371
shifts = [
372372
0,
373-
] * X.ndim
374-
shape = X.shape
373+
] * x.ndim
374+
shape = x.shape
375375
for sh, ax in broadcasted:
376376
n_i = shape[ax]
377-
shifts[ax] = int(shifts[ax] + sh) % int(n_i) if n_i > 0 else 0
377+
shifts[ax] = (int(shifts[ax] + sh) % int(n_i)) if n_i > 0 else 0
378378
res = dpt.empty(
379-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
379+
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
380380
)
381381
dep_evs = _manager.submitted_events
382382
ht_e, roll_ev = ti._copy_usm_ndarray_for_roll_nd(
383-
src=X, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
383+
src=x, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
384384
)
385385
_manager.add_event_pair(ht_e, roll_ev)
386386
return res

0 commit comments

Comments
 (0)