Skip to content

Commit 7ade829

Browse files
If the size of the mask allows, use the int32 type for cumsum to improve performance
1 parent f7ac1f1 commit 7ade829

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
":class:`dpctl.tensor.usm_ndarray`."
3232
)
3333

34+
# Exclusive upper bound used when picking the cumsum dtype: 2**31 is one more
# than the largest signed 32-bit integer, so element counts strictly below it
# select int32 and anything larger falls back to int64.
int32_t_max = 2**31
35+
3436

3537
def _copy_to_numpy(ary):
3638
if not isinstance(ary, dpt.usm_ndarray):
@@ -482,7 +484,8 @@ def _extract_impl(ary, ary_mask, axis=0):
482484
"Parameter p is inconsistent with input array dimensions"
483485
)
484486
mask_nelems = ary_mask.size
485-
cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device)
487+
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
488+
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
486489
exec_q = cumsum.sycl_queue
487490
mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
488491
dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
@@ -509,8 +512,9 @@ def _nonzero_impl(ary):
509512
exec_q = ary.sycl_queue
510513
usm_type = ary.usm_type
511514
mask_nelems = ary.size
515+
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
512516
cumsum = dpt.empty(
513-
mask_nelems, dtype=dpt.int64, sycl_queue=exec_q, order="C"
517+
mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C"
514518
)
515519
mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q)
516520
indexes = dpt.empty(
@@ -604,7 +608,8 @@ def _place_impl(ary, ary_mask, vals, axis=0):
604608
"Parameter p is inconsistent with input array dimensions"
605609
)
606610
mask_nelems = ary_mask.size
607-
cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device)
611+
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
612+
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
608613
exec_q = cumsum.sycl_queue
609614
mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
610615
expected_vals_shape = (

0 commit comments

Comments
 (0)