31
31
":class:`dpctl.tensor.usm_ndarray`."
32
32
)
33
33
34
+ int32_t_max = 2147483648
35
+
34
36
35
37
def _copy_to_numpy (ary ):
36
38
if not isinstance (ary , dpt .usm_ndarray ):
@@ -482,7 +484,8 @@ def _extract_impl(ary, ary_mask, axis=0):
482
484
"Parameter p is inconsistent with input array dimensions"
483
485
)
484
486
mask_nelems = ary_mask .size
485
- cumsum = dpt .empty (mask_nelems , dtype = dpt .int64 , device = ary_mask .device )
487
+ cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
488
+ cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
486
489
exec_q = cumsum .sycl_queue
487
490
mask_count = ti .mask_positions (ary_mask , cumsum , sycl_queue = exec_q )
488
491
dst_shape = ary .shape [:pp ] + (mask_count ,) + ary .shape [pp + mask_nd :]
@@ -509,8 +512,9 @@ def _nonzero_impl(ary):
509
512
exec_q = ary .sycl_queue
510
513
usm_type = ary .usm_type
511
514
mask_nelems = ary .size
515
+ cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
512
516
cumsum = dpt .empty (
513
- mask_nelems , dtype = dpt . int64 , sycl_queue = exec_q , order = "C"
517
+ mask_nelems , dtype = cumsum_dt , sycl_queue = exec_q , order = "C"
514
518
)
515
519
mask_count = ti .mask_positions (ary , cumsum , sycl_queue = exec_q )
516
520
indexes = dpt .empty (
@@ -604,7 +608,8 @@ def _place_impl(ary, ary_mask, vals, axis=0):
604
608
"Parameter p is inconsistent with input array dimensions"
605
609
)
606
610
mask_nelems = ary_mask .size
607
- cumsum = dpt .empty (mask_nelems , dtype = dpt .int64 , device = ary_mask .device )
611
+ cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
612
+ cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
608
613
exec_q = cumsum .sycl_queue
609
614
mask_count = ti .mask_positions (ary_mask , cumsum , sycl_queue = exec_q )
610
615
expected_vals_shape = (
0 commit comments