Skip to content

Commit ce4cf5d

Browse files
Fixed hang on CUDA GPUs in custom scanning code
1 parent 101514f commit ce4cf5d

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
261261
sycl::group_barrier(wg, sycl::memory_scope::work_group);
262262
}
263263

264-
if (sgr_id == 0 && lane_id < n_aggregates) {
264+
if (sgr_id == 0) {
265265
const std::uint32_t offset =
266266
(large_wg) ? n_aggregates - max_sgSize : 0u;
267-
T __scan_val = (offset + lane_id > 0)
267+
const bool in_range = (lane_id < n_aggregates);
268+
const bool in_bounds = in_range && (lane_id > 0 || large_wg);
269+
270+
T __scan_val = (in_bounds)
268271
? local_mem_acc[(offset + lane_id) * max_sgSize - 1]
269272
: identity;
270273
for (std::uint32_t step = 1; step < sgSize; step *= 2) {
@@ -273,12 +276,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
273276
(advanced_lane ? lane_id - step : lane_id);
274277
const T modifier =
275278
sycl::select_from_group(sg, __scan_val, src_lane_id);
276-
if (advanced_lane) {
279+
if (advanced_lane && in_range) {
277280
__scan_val = op(__scan_val, modifier);
278281
}
279282
}
280-
sycl::group_barrier(sg);
281-
local_mem_acc[(offset + lane_id) * max_sgSize - 1] = __scan_val;
283+
if (in_bounds) {
284+
local_mem_acc[(offset + lane_id) * max_sgSize - 1] = __scan_val;
285+
}
282286
}
283287
sycl::group_barrier(wg, sycl::memory_scope::work_group);
284288

0 commit comments

Comments
 (0)