@@ -437,19 +437,30 @@ sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
437
437
dependent_event = exec_q.submit ([&](sycl::handler &cgh) {
438
438
cgh.depends_on (dependent_event);
439
439
440
+ constexpr nwiT updates_per_wi = n_wi;
441
+ size_t n_items = ceiling_quotient<size_t >(src_size, n_wi);
442
+
440
443
using UpdateKernelName =
441
444
class inclusive_scan_1d_iter_chunk_update_krn <
442
445
inputT, outputT, n_wi, IndexerT, TransformerT,
443
446
NoOpTransformerT, ScanOpT, include_initial>;
444
447
445
448
cgh.parallel_for <UpdateKernelName>(
446
- {src_size}, [chunk_size, src, local_scans, scan_op,
447
- identity](auto wiid) {
448
- auto gid = wiid[0 ];
449
- auto i = (gid / chunk_size);
450
- src[gid] = (i > 0 )
451
- ? scan_op (src[gid], local_scans[i - 1 ])
452
- : scan_op (src[gid], identity);
449
+ {n_items}, [chunk_size, src, src_size, local_scans, scan_op,
450
+ identity](auto wiid) {
451
+ auto gid = n_wi * wiid[0 ];
452
+ #pragma unroll
453
+ for (auto i = 0 ; i < updates_per_wi; ++i) {
454
+ auto src_id = gid + i;
455
+ if (src_id < src_size) {
456
+ auto scan_id = (src_id / chunk_size);
457
+ src[src_id] =
458
+ (scan_id > 0 )
459
+ ? scan_op (src[src_id],
460
+ local_scans[scan_id - 1 ])
461
+ : scan_op (src[src_id], identity);
462
+ }
463
+ }
453
464
});
454
465
});
455
466
}
0 commit comments