Skip to content

Commit 5125e11

Browse files
oleksandr-pavlyk authored and ndgrigorian committed
Apply work-around for failing tests with CPU device and short sub-groups
The team developing the OpenCL:CPU device runtime and compiler was notified. See CMPLRLLVM-64592. Once fixed, the work-around should be removed.
1 parent 0493485 commit 5125e11

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,24 @@ struct subgroup_radix_sort
12531253
const std::size_t n_batches =
12541254
(n_iters + n_batch_size - 1) / n_batch_size;
12551255

1256+
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
1257+
1258+
auto const &ctx = exec_q.get_context();
1259+
auto const &dev = exec_q.get_device();
1260+
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
1261+
ctx, {dev}, {kernel_id});
1262+
1263+
const auto &krn = kb.get_kernel(kernel_id);
1264+
1265+
const std::uint32_t krn_sg_size = krn.template get_info<
1266+
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
1267+
1268+
// due to a bug in CPU device implementation, an additional
1269+
// synchronization is necessary for short sub-group sizes
1270+
const bool work_around_needed =
1271+
exec_q.get_device().has(sycl::aspect::cpu) &&
1272+
(krn_sg_size < 16);
1273+
12561274
for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) {
12571275

12581276
const std::size_t block_start = batch_id * n_batch_size;
@@ -1269,6 +1287,7 @@ struct subgroup_radix_sort
12691287

12701288
sort_ev = exec_q.submit([&](sycl::handler &cgh) {
12711289
cgh.depends_on(deps);
1290+
cgh.use_kernel_bundle(kb);
12721291

12731292
// allocation to use for value exchanges
12741293
auto exchange_acc = buf_val.get_acc(cgh);
@@ -1357,6 +1376,11 @@ struct subgroup_radix_sort
13571376
counters[i] = &pcounter[bin * wg_size];
13581377
indices[i] = *counters[i];
13591378
*counters[i] = indices[i] + 1;
1379+
1380+
if (work_around_needed) {
1381+
sycl::group_barrier(
1382+
ndit.get_group());
1383+
}
13601384
}
13611385
}
13621386
else {
@@ -1389,6 +1413,11 @@ struct subgroup_radix_sort
13891413
counters[i] = &pcounter[bin * wg_size];
13901414
indices[i] = *counters[i];
13911415
*counters[i] = indices[i] + 1;
1416+
1417+
if (work_around_needed) {
1418+
sycl::group_barrier(
1419+
ndit.get_group());
1420+
}
13921421
}
13931422
}
13941423

0 commit comments

Comments
 (0)