@@ -70,13 +70,14 @@ template <typename SizeT,
          int> = 0>
std::uint32_t ceil_log2(SizeT n)
{
+    // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
+    // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == b + floor_log2(q)
+    // ceil_log2(n) == 1 + floor_log2(n-1)
    if (n <= 1)
        return std::uint32_t{1};

    std::uint32_t exp{1};
    --n;
-    // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
-    // ceil_log2(q * 2^b + r) == ceil_log2(q * 2^b) == q + ceil_log2(n1)
    if (n >= (SizeT{1} << 32)) {
        n >>= 32;
        exp += 32;
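The relocated comment now states the identity the loop below relies on: for n > 1, ceil_log2(n) == 1 + floor_log2(n - 1). A minimal host-side sketch of the same contract (ceil_log2_ref is a hypothetical reference helper for illustration, not part of this patch):

    #include <cstdint>

    // Smallest e with 2^e >= n, clamped to at least 1.
    // Computes 1 + floor_log2(n - 1) by counting shifts of (n - 1).
    std::uint32_t ceil_log2_ref(std::uint64_t n)
    {
        if (n <= 1)
            return 1;
        std::uint32_t e = 0;
        for (std::uint64_t m = n - 1; m > 0; m >>= 1)
            ++e; // e ends up as 1 + floor_log2(n - 1)
        return e;
    }

    // ceil_log2_ref(2) == 1, ceil_log2_ref(3) == 2,
    // ceil_log2_ref(8) == 3, ceil_log2_ref(9) == 4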
@@ -137,16 +138,20 @@ template <bool is_ascending,
std::make_unsigned_t<IntT> order_preserving_cast(IntT val)
{
    using UIntT = std::make_unsigned_t<IntT>;
-    // ascending_mask: 100..0
-    constexpr UIntT ascending_mask =
-        (UIntT(1) << std::numeric_limits<IntT>::digits);
-    // descending_mask: 011..1
-    constexpr UIntT descending_mask = (std::numeric_limits<UIntT>::max() >> 1);
-
-    constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask;
    const UIntT uint_val = sycl::bit_cast<UIntT>(val);

-    return (uint_val ^ mask);
+    if constexpr (is_ascending) {
+        // ascending_mask: 100..0
+        constexpr UIntT ascending_mask =
+            (UIntT(1) << std::numeric_limits<IntT>::digits);
+        return (uint_val ^ ascending_mask);
+    }
+    else {
+        // descending_mask: 011..1
+        constexpr UIntT descending_mask =
+            (std::numeric_limits<UIntT>::max() >> 1);
+        return (uint_val ^ descending_mask);
+    }
}

template <bool is_ascending> std::uint16_t order_preserving_cast(sycl::half val)
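The refactor only moves each mask into its `if constexpr` branch; the mapping is unchanged. XOR with 100..0 flips the sign bit, so the resulting unsigned values compare in the same order as the signed inputs; XOR with 011..1 flips every bit except the sign, which reverses that order for descending sorts. A host-side sketch of the ascending branch, assuming C++20 std::bit_cast as a stand-in for sycl::bit_cast (order_preserving_cast_ref is hypothetical):

    #include <bit>
    #include <cstdint>
    #include <limits>

    // Flipping the sign bit maps int32_t to uint32_t so that
    // unsigned comparison matches signed comparison.
    std::uint32_t order_preserving_cast_ref(std::int32_t val)
    {
        const auto u = std::bit_cast<std::uint32_t>(val);
        constexpr std::uint32_t ascending_mask =
            (std::uint32_t{1} << std::numeric_limits<std::int32_t>::digits);
        // INT32_MIN -> 0, -1 -> 0x7FFFFFFF, 0 -> 0x80000000
        return u ^ ascending_mask;
    }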
@@ -1045,10 +1050,10 @@ template <typename Names, std::uint16_t... Constants>
class radix_sort_one_wg_krn;

template <typename KernelNameBase,
-          uint16_t wg_size = 256,
-          uint16_t block_size = 16,
+          std::uint16_t wg_size = 256,
+          std::uint16_t block_size = 16,
          std::uint32_t radix = 4,
-          uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16)>
+          std::uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16)>
struct subgroup_radix_sort
{
private:
@@ -1062,8 +1067,8 @@ struct subgroup_radix_sort
public:
    template <typename ValueT, typename OutputT, typename ProjT>
    sycl::event operator()(sycl::queue &exec_q,
-                           size_t n_iters,
-                           size_t n_to_sort,
+                           std::size_t n_iters,
+                           std::size_t n_to_sort,
                           ValueT *input_ptr,
                           OutputT *output_ptr,
                           ProjT proj_op,
@@ -1160,8 +1165,8 @@ struct subgroup_radix_sort
    };

    static_assert(wg_size <= 1024);
-    static constexpr uint16_t bin_count = (1 << radix);
-    static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1;
+    static constexpr std::uint16_t bin_count = (1 << radix);
+    static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1;

    enum class temp_allocations
    {
@@ -1177,7 +1182,7 @@ struct subgroup_radix_sort
        assert(n <= (SizeT(1) << 16));

        constexpr auto req_slm_size_counters =
-            counter_buf_sz * sizeof(uint32_t);
+            counter_buf_sz * sizeof(std::uint16_t);

        const auto &dev = exec_q.get_device();

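Using sizeof(std::uint16_t) here keeps the SLM requirement consistent with the narrower counter buffer allocated later in this patch, halving the per-work-group scratch footprint. A sketch of the kind of feasibility check this feeds into, under the assumption that the standard SYCL 2020 local_mem_size query is what bounds it (counters_fit_in_slm is a hypothetical helper):

    #include <sycl/sycl.hpp>
    #include <cstddef>
    #include <cstdint>

    // Does the per-work-group counter scratch fit in device local memory?
    bool counters_fit_in_slm(const sycl::queue &q, std::size_t counter_buf_sz)
    {
        const auto dev = q.get_device();
        const std::uint64_t slm_limit =
            dev.get_info<sycl::info::device::local_mem_size>();
        // uint16_t counters: half the footprint of the previous uint32_t ones
        return counter_buf_sz * sizeof(std::uint16_t) <= slm_limit;
    }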
@@ -1212,9 +1217,9 @@ struct subgroup_radix_sort
              typename SLM_value_tag,
              typename SLM_counter_tag>
    sycl::event operator()(sycl::queue &exec_q,
-                           size_t n_iters,
-                           size_t n_batch_size,
-                           size_t n_values,
+                           std::size_t n_iters,
+                           std::size_t n_batch_size,
+                           std::size_t n_values,
                           InputT *input_arr,
                           OutputT *output_arr,
                           const ProjT &proj_op,
@@ -1228,7 +1233,7 @@ struct subgroup_radix_sort
        assert(n_values <= static_cast<std::size_t>(block_size) *
                               static_cast<std::size_t>(wg_size));

-        uint16_t n = static_cast<uint16_t>(n_values);
+        const std::uint16_t n = static_cast<std::uint16_t>(n_values);
        static_assert(std::is_same_v<std::remove_cv_t<InputT>, OutputT>);

        using ValueT = OutputT;
@@ -1237,17 +1242,18 @@ struct subgroup_radix_sort

        TempBuf<ValueT, SLM_value_tag> buf_val(
            n_batch_size, static_cast<std::size_t>(block_size * wg_size));
-        TempBuf<std::uint32_t, SLM_counter_tag> buf_count(
+        TempBuf<std::uint16_t, SLM_counter_tag> buf_count(
            n_batch_size, static_cast<std::size_t>(counter_buf_sz));

        sycl::range<1> lRange{wg_size};

        sycl::event sort_ev;
-        std::vector<sycl::event> deps = depends;
+        std::vector<sycl::event> deps{depends};

-        std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size;
+        const std::size_t n_batches =
+            (n_iters + n_batch_size - 1) / n_batch_size;

-        for (size_t batch_id = 0; batch_id < n_batches; ++batch_id) {
+        for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) {

            const std::size_t block_start = batch_id * n_batch_size;

@@ -1286,46 +1292,49 @@ struct subgroup_radix_sort
                    const std::size_t iter_exchange_offset =
                        iter_id * exchange_acc_iter_stride;

-                    uint16_t wi = ndit.get_local_linear_id();
-                    uint16_t begin_bit = 0;
+                    std::uint16_t wi = ndit.get_local_linear_id();
+                    std::uint16_t begin_bit = 0;

-                    constexpr uint16_t end_bit =
+                    constexpr std::uint16_t end_bit =
                        number_of_bits_in_type<KeyT>();

-                    // copy from input array into values
+                    // copy from input array into values
#pragma unroll
-                    for (uint16_t i = 0; i < block_size; ++i) {
-                        const uint16_t id = wi * block_size + i;
-                        if (id < n)
-                            values[i] = std::move(
-                                this_input_arr[iter_val_offset + id]);
+                    for (std::uint16_t i = 0; i < block_size; ++i) {
+                        const std::uint16_t id = wi * block_size + i;
+                        values[i] =
+                            (id < n) ? this_input_arr[iter_val_offset + id]
+                                     : ValueT{};
                    }

                    while (true) {
                        // indices for indirect access in the "re-order"
                        // phase
-                        uint16_t indices[block_size];
+                        std::uint16_t indices[block_size];
                        {
                            // pointers to bucket's counters
-                            uint32_t *counters[block_size];
+                            std::uint16_t *counters[block_size];

                            // counting phase
                            auto pcounter =
                                get_accessor_pointer(counter_acc) +
                                (wi + iter_counter_offset);

-                            // initialize counters
+                            // initialize counters
#pragma unroll
-                            for (uint16_t i = 0; i < bin_count; ++i)
-                                pcounter[i * wg_size] = std::uint32_t{0};
+                            for (std::uint16_t i = 0; i < bin_count; ++i)
+                                pcounter[i * wg_size] = std::uint16_t{0};

                            sycl::group_barrier(ndit.get_group());

                            if (is_ascending) {
#pragma unroll
-                                for (uint16_t i = 0; i < block_size; ++i) {
-                                    const uint16_t id = wi * block_size + i;
-                                    constexpr uint16_t bin_mask =
+                                for (std::uint16_t i = 0; i < block_size;
+                                     ++i)
+                                {
+                                    const std::uint16_t id =
+                                        wi * block_size + i;
+                                    constexpr std::uint16_t bin_mask =
                                        bin_count - 1;

                                    // points to the padded element, i.e. id
@@ -1334,7 +1343,7 @@ struct subgroup_radix_sort
                                        default_out_of_range_bin_id =
                                            bin_mask;

-                                    const uint16_t bin =
+                                    const std::uint16_t bin =
                                        (id < n)
                                            ? get_bucket_id<bin_mask>(
                                                  order_preserving_cast<
@@ -1352,9 +1361,12 @@ struct subgroup_radix_sort
                            }
                            else {
#pragma unroll
-                                for (uint16_t i = 0; i < block_size; ++i) {
-                                    const uint16_t id = wi * block_size + i;
-                                    constexpr uint16_t bin_mask =
+                                for (std::uint16_t i = 0; i < block_size;
+                                     ++i)
+                                {
+                                    const std::uint16_t id =
+                                        wi * block_size + i;
+                                    constexpr std::uint16_t bin_mask =
                                        bin_count - 1;

                                    // points to the padded element, i.e. id
@@ -1363,7 +1375,7 @@ struct subgroup_radix_sort
                                        default_out_of_range_bin_id =
                                            bin_mask;

-                                    const uint16_t bin =
+                                    const std::uint16_t bin =
                                        (id < n)
                                            ? get_bucket_id<bin_mask>(
                                                  order_preserving_cast<
@@ -1386,29 +1398,31 @@ struct subgroup_radix_sort
                        {

                            // scan contiguous numbers
-                            uint16_t bin_sum[bin_count];
+                            std::uint16_t bin_sum[bin_count];
                            const std::size_t counter_offset0 =
                                iter_counter_offset + wi * bin_count;
                            bin_sum[0] = counter_acc[counter_offset0];

#pragma unroll
-                            for (uint16_t i = 1; i < bin_count; ++i)
+                            for (std::uint16_t i = 1; i < bin_count;
+                                 ++i)
                                bin_sum[i] =
                                    bin_sum[i - 1] +
                                    counter_acc[counter_offset0 + i];

                            sycl::group_barrier(ndit.get_group());

                            // exclusive scan local sum
-                            uint16_t sum_scan =
+                            std::uint16_t sum_scan =
                                sycl::exclusive_scan_over_group(
                                    ndit.get_group(),
                                    bin_sum[bin_count - 1],
-                                    sycl::plus<uint16_t>());
+                                    sycl::plus<std::uint16_t>());

                            // add to local sum, generate exclusive scan result
#pragma unroll
-                            for (uint16_t i = 0; i < bin_count; ++i)
+                            for (std::uint16_t i = 0; i < bin_count;
+                                 ++i)
                                counter_acc[counter_offset0 + i + 1] =
                                    sum_scan + bin_sum[i];

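The counting and scan hunks above implement per-digit ranking with the narrowed std::uint16_t counters: count each bin, exclusive-scan the counts, then use the running totals as destination indices. A sequential host-side analogue for one digit (rank_by_digit is a hypothetical illustration; the kernel additionally stripes counters by wg_size and scans cooperatively across the work-group):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // For one radix digit, each key's destination index is the
    // exclusive-scanned count of its bin (a stable scatter).
    std::vector<std::uint16_t>
    rank_by_digit(const std::vector<std::uint16_t> &keys,
                  std::uint16_t begin_bit, std::uint16_t radix)
    {
        const std::uint16_t bin_count = std::uint16_t(1) << radix;
        const std::uint16_t bin_mask = bin_count - 1;

        // counting phase
        std::vector<std::uint16_t> counters(bin_count, 0);
        for (auto k : keys)
            ++counters[(k >> begin_bit) & bin_mask];

        // exclusive scan over bins
        std::uint16_t sum = 0;
        for (auto &c : counters) {
            const std::uint16_t tmp = c;
            c = sum;
            sum += tmp;
        }

        // each key lands at the next free slot of its bin
        std::vector<std::uint16_t> indices(keys.size());
        for (std::size_t i = 0; i < keys.size(); ++i)
            indices[i] = counters[(keys[i] >> begin_bit) & bin_mask]++;
        return indices;
    }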
@@ -1420,51 +1434,50 @@ struct subgroup_radix_sort
                            }

#pragma unroll
-                            for (uint16_t i = 0; i < block_size; ++i) {
+                            for (std::uint16_t i = 0; i < block_size; ++i) {
                                // a global index is a local offset plus a
                                // global base index
                                indices[i] += *counters[i];
                            }
+
+                            sycl::group_barrier(ndit.get_group());
                        }

                        begin_bit += radix;

                        // "re-order" phase
                        sycl::group_barrier(ndit.get_group());
                        if (begin_bit >= end_bit) {
-                            // the last iteration - writing out the result
+                            // the last iteration - writing out the result
#pragma unroll
-                            for (uint16_t i = 0; i < block_size; ++i) {
-                                const uint16_t r = indices[i];
+                            for (std::uint16_t i = 0; i < block_size; ++i) {
+                                const std::uint16_t r = indices[i];
                                if (r < n) {
-                                    // move the values to source range and
-                                    // destroy the values
                                    this_output_arr[iter_val_offset + r] =
-                                        std::move(values[i]);
+                                        values[i];
                                }
                            }

                            return;
                        }

-                        // data exchange
+                        // data exchange
#pragma unroll
-                        for (uint16_t i = 0; i < block_size; ++i) {
-                            const uint16_t r = indices[i];
+                        for (std::uint16_t i = 0; i < block_size; ++i) {
+                            const std::uint16_t r = indices[i];
                            if (r < n)
                                exchange_acc[iter_exchange_offset + r] =
-                                    std::move(values[i]);
+                                    values[i];
                        }

                        sycl::group_barrier(ndit.get_group());

#pragma unroll
-                        for (uint16_t i = 0; i < block_size; ++i) {
-                            const uint16_t id = wi * block_size + i;
+                        for (std::uint16_t i = 0; i < block_size; ++i) {
+                            const std::uint16_t id = wi * block_size + i;
                            if (id < n)
-                                values[i] = std::move(
-                                    exchange_acc[iter_exchange_offset +
-                                                 id]);
+                                values[i] =
+                                    exchange_acc[iter_exchange_offset + id];
                        }

                        sycl::group_barrier(ndit.get_group());
@@ -1736,10 +1749,10 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
                             const bool sort_ascending,
                             // number of sub-arrays to sort (num. of rows in a
                             // matrix when sorting over rows)
-                             size_t iter_nelems,
+                             std::size_t iter_nelems,
                             // size of each array to sort (length of rows,
                             // i.e. number of columns)
-                             size_t sort_nelems,
+                             std::size_t sort_nelems,
                             const char *arg_cp,
                             char *res_cp,
                             ssize_t iter_arg_offset,
@@ -1775,10 +1788,10 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
                                const bool sort_ascending,
                                // number of sub-arrays to sort (num. of
                                // rows in a matrix when sorting over rows)
-                                size_t iter_nelems,
+                                std::size_t iter_nelems,
                                // size of each array to sort (length of
                                // rows, i.e. number of columns)
-                                size_t sort_nelems,
+                                std::size_t sort_nelems,
                                const char *arg_cp,
                                char *res_cp,
                                ssize_t iter_arg_offset,