@@ -1253,6 +1253,24 @@ struct subgroup_radix_sort
         const std::size_t n_batches =
             (n_iters + n_batch_size - 1) / n_batch_size;
 
+        const auto &kernel_id = sycl::get_kernel_id<KernelName>();
+
+        auto const &ctx = exec_q.get_context();
+        auto const &dev = exec_q.get_device();
+        auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
+            ctx, {dev}, {kernel_id});
+
+        const auto &krn = kb.get_kernel(kernel_id);
+
+        const std::uint32_t krn_sg_size = krn.template get_info<
+            sycl::info::kernel_device_specific::max_sub_group_size>(dev);
+
+        // due to a bug in CPU device implementation, an additional
+        // synchronization is necessary for short sub-group sizes
+        const bool work_around_needed =
+            exec_q.get_device().has(sycl::aspect::cpu) &&
+            (krn_sg_size < 16);
+
         for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) {
 
             const std::size_t block_start = batch_id * n_batch_size;
@@ -1269,6 +1287,7 @@ struct subgroup_radix_sort
 
             sort_ev = exec_q.submit([&](sycl::handler &cgh) {
                 cgh.depends_on(deps);
+                cgh.use_kernel_bundle(kb);
 
                 // allocation to use for value exchanges
                 auto exchange_acc = buf_val.get_acc(cgh);
@@ -1357,6 +1376,11 @@ struct subgroup_radix_sort
                             counters[i] = &pcounter[bin * wg_size];
                             indices[i] = *counters[i];
                             *counters[i] = indices[i] + 1;
+
+                            if (work_around_needed) {
+                                sycl::group_barrier(
+                                    ndit.get_group());
+                            }
                         }
                     }
                     else {
@@ -1389,6 +1413,11 @@ struct subgroup_radix_sort
                             counters[i] = &pcounter[bin * wg_size];
                             indices[i] = *counters[i];
                             *counters[i] = indices[i] + 1;
+
+                            if (work_around_needed) {
+                                sycl::group_barrier(
+                                    ndit.get_group());
+                            }
                         }
                     }
 
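
For context, the pattern this commit introduces can be shown standalone: build an executable kernel bundle up front, query the kernel's maximum sub-group size through it, decide whether the CPU barrier workaround applies, and reuse the same bundle at submission via use_kernel_bundle so the launched kernel matches the one queried. The sketch below is illustrative only; ExampleKernel, the queue setup, and the nd_range sizes are assumptions, not part of this change.

// Standalone sketch (not from the commit): introspect a kernel's
// max sub-group size before launch and derive the same CPU
// workaround flag. ExampleKernel and all sizes are hypothetical.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <iostream>

class ExampleKernel; // forward-declared kernel name type

int main()
{
    sycl::queue q;
    const auto &ctx = q.get_context();
    const auto &dev = q.get_device();

    // Build an executable bundle containing only this kernel so it
    // can be introspected before any work is submitted.
    auto kernel_id = sycl::get_kernel_id<ExampleKernel>();
    auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
        ctx, {dev}, {kernel_id});
    auto krn = kb.get_kernel(kernel_id);

    const std::uint32_t sg_size = krn.get_info<
        sycl::info::kernel_device_specific::max_sub_group_size>(dev);

    // Same predicate as in the diff: the extra barrier is only
    // needed on CPU devices with short sub-group sizes.
    const bool work_around_needed =
        dev.has(sycl::aspect::cpu) && (sg_size < 16);

    q.submit([&](sycl::handler &cgh) {
         // Reuse the pre-built bundle so the launched kernel is the
         // same object that was queried above.
         cgh.use_kernel_bundle(kb);
         cgh.parallel_for<ExampleKernel>(
             sycl::nd_range<1>{sycl::range<1>{1024}, sycl::range<1>{64}},
             [=](sycl::nd_item<1> ndit) {
                 // ... per-item work would go here ...

                 // work_around_needed is uniform across the group, so
                 // every work-item reaches the barrier together.
                 if (work_around_needed) {
                     sycl::group_barrier(ndit.get_group());
                 }
             });
     }).wait();

    std::cout << "max sub-group size: " << sg_size << "\n";
    return 0;
}

Submitting through the pre-built bundle matters here: without use_kernel_bundle, the runtime may compile a separate copy of the kernel whose sub-group size need not match the value queried, which is presumably why the diff adds cgh.use_kernel_bundle(kb) alongside the query.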