Skip to content

Commit 60f60c6

Browse files
Add a comment about constant choice
1 parent dd9a873 commit 60f60c6

File tree

1 file changed

+21
-24
lines changed

1 file changed

+21
-24
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 21 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -273,8 +273,7 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q,
273273

274274
outputT wg_iscan_val;
275275
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
276-
outputT>::value)
277-
{
276+
outputT>::value) {
278277
wg_iscan_val = sycl::inclusive_scan_over_group(
279278
it.get_group(), local_iscan.back(), scan_op, identity);
280279
}
@@ -447,8 +446,7 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
447446

448447
outputT wg_iscan_val;
449448
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
450-
outputT>::value)
451-
{
449+
outputT>::value) {
452450
wg_iscan_val = sycl::inclusive_scan_over_group(
453451
it.get_group(), local_iscan.back(), scan_op, identity);
454452
}
@@ -472,35 +470,32 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
472470
it.barrier(sycl::access::fence_space::local_space);
473471

474472
// convert back to blocked layout
475-
{
476-
{
477-
const std::uint32_t local_offset0 = lid * n_wi;
473+
{{const std::uint32_t local_offset0 = lid * n_wi;
478474
#pragma unroll
479-
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
480-
slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
481-
}
475+
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
476+
slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
477+
}
482478

483-
it.barrier(sycl::access::fence_space::local_space);
479+
it.barrier(sycl::access::fence_space::local_space);
484480
}
485481
}
486482

487483
{
488-
const std::uint32_t block_offset =
489-
sgroup_id * sgSize * n_wi + lane_id;
484+
const std::uint32_t block_offset = sgroup_id * sgSize * n_wi + lane_id;
490485
#pragma unroll
491-
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
492-
const std::uint32_t m_wi_scaled = m_wi * sgSize;
493-
const std::size_t out_id = inp_id0 + m_wi_scaled;
494-
if (out_id < acc_nelems) {
495-
output[out_iter_offset + out_indexer(out_id)] =
496-
slm_iscan_tmp[block_offset + m_wi_scaled];
497-
}
498-
}
486+
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
487+
const std::uint32_t m_wi_scaled = m_wi * sgSize;
488+
const std::size_t out_id = inp_id0 + m_wi_scaled;
489+
if (out_id < acc_nelems) {
490+
output[out_iter_offset + out_indexer(out_id)] =
491+
slm_iscan_tmp[block_offset + m_wi_scaled];
499492
}
500-
});
501-
});
493+
}
494+
}
495+
});
496+
});
502497

503-
return inc_scan_phase1_ev;
498+
return inc_scan_phase1_ev;
504499
}
505500

506501
template <typename inputT,
@@ -530,6 +525,8 @@ inclusive_scan_base_step(sycl::queue &exec_q,
530525
std::size_t &acc_groups,
531526
const std::vector<sycl::event> &depends = {})
532527
{
528+
// For small stride use striped load/store.
529+
// Threshold value chosen experimentally.
533530
if (s1 <= 16) {
534531
return inclusive_scan_base_step_striped<
535532
inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,

0 commit comments

Comments
 (0)