@@ -273,8 +273,7 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q,
273
273
274
274
outputT wg_iscan_val;
275
275
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
276
- outputT>::value)
277
- {
276
+ outputT>::value) {
278
277
wg_iscan_val = sycl::inclusive_scan_over_group (
279
278
it.get_group (), local_iscan.back (), scan_op, identity);
280
279
}
@@ -447,8 +446,7 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
447
446
448
447
outputT wg_iscan_val;
449
448
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
450
- outputT>::value)
451
- {
449
+ outputT>::value) {
452
450
wg_iscan_val = sycl::inclusive_scan_over_group (
453
451
it.get_group (), local_iscan.back (), scan_op, identity);
454
452
}
@@ -472,35 +470,32 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
472
470
it.barrier (sycl::access::fence_space::local_space);
473
471
474
472
// convert back to blocked layout
475
- {
476
- {
477
- const std::uint32_t local_offset0 = lid * n_wi;
473
+ {{const std::uint32_t local_offset0 = lid * n_wi;
478
474
#pragma unroll
479
- for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
480
- slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
481
- }
475
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
476
+ slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
477
+ }
482
478
483
- it.barrier (sycl::access::fence_space::local_space);
479
+ it.barrier (sycl::access::fence_space::local_space);
484
480
}
485
481
}
486
482
487
483
{
488
- const std::uint32_t block_offset =
489
- sgroup_id * sgSize * n_wi + lane_id;
484
+ const std::uint32_t block_offset = sgroup_id * sgSize * n_wi + lane_id;
490
485
#pragma unroll
491
- for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
492
- const std::uint32_t m_wi_scaled = m_wi * sgSize;
493
- const std::size_t out_id = inp_id0 + m_wi_scaled;
494
- if (out_id < acc_nelems) {
495
- output[out_iter_offset + out_indexer (out_id)] =
496
- slm_iscan_tmp[block_offset + m_wi_scaled];
497
- }
498
- }
486
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
487
+ const std::uint32_t m_wi_scaled = m_wi * sgSize;
488
+ const std::size_t out_id = inp_id0 + m_wi_scaled;
489
+ if (out_id < acc_nelems) {
490
+ output[out_iter_offset + out_indexer (out_id)] =
491
+ slm_iscan_tmp[block_offset + m_wi_scaled];
499
492
}
500
- });
501
- });
493
+ }
494
+ }
495
+ });
496
+ });
502
497
503
- return inc_scan_phase1_ev;
498
+ return inc_scan_phase1_ev;
504
499
}
505
500
506
501
template <typename inputT,
@@ -530,6 +525,8 @@ inclusive_scan_base_step(sycl::queue &exec_q,
530
525
std::size_t &acc_groups,
531
526
const std::vector<sycl::event> &depends = {})
532
527
{
528
+ // For small stride use striped load/store.
529
+ // Threshold value chosen experimentally.
533
530
if (s1 <= 16 ) {
534
531
return inclusive_scan_base_step_striped<
535
532
inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
0 commit comments