File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed
dpctl/tensor/libtensor/include/utils Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -261,10 +261,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
261
261
sycl::group_barrier (wg, sycl::memory_scope::work_group);
262
262
}
263
263
264
- if (sgr_id == 0 && lane_id < n_aggregates ) {
264
+ if (sgr_id == 0 ) {
265
265
const std::uint32_t offset =
266
266
(large_wg) ? n_aggregates - max_sgSize : 0u ;
267
- T __scan_val = (offset + lane_id > 0 )
267
+ const bool in_range = (lane_id < n_aggregates);
268
+ const bool in_bounds = in_range && (lane_id > 0 || large_wg);
269
+
270
+ T __scan_val = (in_bounds)
268
271
? local_mem_acc[(offset + lane_id) * max_sgSize - 1 ]
269
272
: identity;
270
273
for (std::uint32_t step = 1 ; step < sgSize; step *= 2 ) {
@@ -273,12 +276,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
273
276
(advanced_lane ? lane_id - step : lane_id);
274
277
const T modifier =
275
278
sycl::select_from_group (sg, __scan_val, src_lane_id);
276
- if (advanced_lane) {
279
+ if (advanced_lane && in_range ) {
277
280
__scan_val = op (__scan_val, modifier);
278
281
}
279
282
}
280
- sycl::group_barrier (sg);
281
- local_mem_acc[(offset + lane_id) * max_sgSize - 1 ] = __scan_val;
283
+ if (in_bounds) {
284
+ local_mem_acc[(offset + lane_id) * max_sgSize - 1 ] = __scan_val;
285
+ }
282
286
}
283
287
sycl::group_barrier (wg, sycl::memory_scope::work_group);
284
288
You can’t perform that action at this time.
0 commit comments