Skip to content

Commit 441e081

Browse files
committed
Increase work per work item in inclusive_scan_iter_1d update step
1 parent 859c747 commit 441e081

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -437,19 +437,30 @@ sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
437437
dependent_event = exec_q.submit([&](sycl::handler &cgh) {
438438
cgh.depends_on(dependent_event);
439439

440+
constexpr nwiT updates_per_wi = n_wi;
441+
size_t n_items = ceiling_quotient<size_t>(src_size, n_wi);
442+
440443
using UpdateKernelName =
441444
class inclusive_scan_1d_iter_chunk_update_krn<
442445
inputT, outputT, n_wi, IndexerT, TransformerT,
443446
NoOpTransformerT, ScanOpT, include_initial>;
444447

445448
cgh.parallel_for<UpdateKernelName>(
446-
{src_size}, [chunk_size, src, local_scans, scan_op,
447-
identity](auto wiid) {
448-
auto gid = wiid[0];
449-
auto i = (gid / chunk_size);
450-
src[gid] = (i > 0)
451-
? scan_op(src[gid], local_scans[i - 1])
452-
: scan_op(src[gid], identity);
449+
{n_items}, [chunk_size, src, src_size, local_scans, scan_op,
450+
identity](auto wiid) {
451+
auto gid = n_wi * wiid[0];
452+
#pragma unroll
453+
for (auto i = 0; i < updates_per_wi; ++i) {
454+
auto src_id = gid + i;
455+
if (src_id < src_size) {
456+
auto scan_id = (src_id / chunk_size);
457+
src[src_id] =
458+
(scan_id > 0)
459+
? scan_op(src[src_id],
460+
local_scans[scan_id - 1])
461+
: scan_op(src[src_id], identity);
462+
}
463+
}
453464
});
454465
});
455466
}

0 commit comments

Comments
 (0)