Skip to content

Commit 3bd3338

Browse files
Counters in one-workgroup kernel to use uint16_t from uint32_t
Counters cannot exceed the uint16_t maximum, because the kernel assumes that the number of elements to sort fits into uint16_t. The change reduces the kernel's SLM footprint. Also: remove uses of std::move; replace uint16_t -> std::uint16_t, size_t -> std::size_t, and uint32_t -> std::uint32_t; use `if constexpr` in order_preserving_cast for better readability.
1 parent 1a64014 commit 3bd3338

File tree

1 file changed

+86
-73
lines changed

1 file changed

+86
-73
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,14 @@ template <typename SizeT,
7070
int> = 0>
7171
std::uint32_t ceil_log2(SizeT n)
7272
{
73+
// if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
74+
// floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == b + floor_log2(q)
75+
// ceil_log2(n) == 1 + floor_log2(n-1)
7376
if (n <= 1)
7477
return std::uint32_t{1};
7578

7679
std::uint32_t exp{1};
7780
--n;
78-
// if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
79-
// ceil_log2(q * 2^b + r) == ceil_log2(q * 2^b) == q + ceil_log2(n1)
8081
if (n >= (SizeT{1} << 32)) {
8182
n >>= 32;
8283
exp += 32;
@@ -137,16 +138,20 @@ template <bool is_ascending,
137138
std::make_unsigned_t<IntT> order_preserving_cast(IntT val)
138139
{
139140
using UIntT = std::make_unsigned_t<IntT>;
140-
// ascending_mask: 100..0
141-
constexpr UIntT ascending_mask =
142-
(UIntT(1) << std::numeric_limits<IntT>::digits);
143-
// descending_mask: 011..1
144-
constexpr UIntT descending_mask = (std::numeric_limits<UIntT>::max() >> 1);
145-
146-
constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask;
147141
const UIntT uint_val = sycl::bit_cast<UIntT>(val);
148142

149-
return (uint_val ^ mask);
143+
if constexpr (is_ascending) {
144+
// ascending_mask: 100..0
145+
constexpr UIntT ascending_mask =
146+
(UIntT(1) << std::numeric_limits<IntT>::digits);
147+
return (uint_val ^ ascending_mask);
148+
}
149+
else {
150+
// descending_mask: 011..1
151+
constexpr UIntT descending_mask =
152+
(std::numeric_limits<UIntT>::max() >> 1);
153+
return (uint_val ^ descending_mask);
154+
}
150155
}
151156

152157
template <bool is_ascending> std::uint16_t order_preserving_cast(sycl::half val)
@@ -1045,10 +1050,10 @@ template <typename Names, std::uint16_t... Constants>
10451050
class radix_sort_one_wg_krn;
10461051

10471052
template <typename KernelNameBase,
1048-
uint16_t wg_size = 256,
1049-
uint16_t block_size = 16,
1053+
std::uint16_t wg_size = 256,
1054+
std::uint16_t block_size = 16,
10501055
std::uint32_t radix = 4,
1051-
uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16)>
1056+
std::uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16)>
10521057
struct subgroup_radix_sort
10531058
{
10541059
private:
@@ -1062,8 +1067,8 @@ struct subgroup_radix_sort
10621067
public:
10631068
template <typename ValueT, typename OutputT, typename ProjT>
10641069
sycl::event operator()(sycl::queue &exec_q,
1065-
size_t n_iters,
1066-
size_t n_to_sort,
1070+
std::size_t n_iters,
1071+
std::size_t n_to_sort,
10671072
ValueT *input_ptr,
10681073
OutputT *output_ptr,
10691074
ProjT proj_op,
@@ -1160,8 +1165,8 @@ struct subgroup_radix_sort
11601165
};
11611166

11621167
static_assert(wg_size <= 1024);
1163-
static constexpr uint16_t bin_count = (1 << radix);
1164-
static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1;
1168+
static constexpr std::uint16_t bin_count = (1 << radix);
1169+
static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1;
11651170

11661171
enum class temp_allocations
11671172
{
@@ -1177,7 +1182,7 @@ struct subgroup_radix_sort
11771182
assert(n <= (SizeT(1) << 16));
11781183

11791184
constexpr auto req_slm_size_counters =
1180-
counter_buf_sz * sizeof(uint32_t);
1185+
counter_buf_sz * sizeof(std::uint16_t);
11811186

11821187
const auto &dev = exec_q.get_device();
11831188

@@ -1212,9 +1217,9 @@ struct subgroup_radix_sort
12121217
typename SLM_value_tag,
12131218
typename SLM_counter_tag>
12141219
sycl::event operator()(sycl::queue &exec_q,
1215-
size_t n_iters,
1216-
size_t n_batch_size,
1217-
size_t n_values,
1220+
std::size_t n_iters,
1221+
std::size_t n_batch_size,
1222+
std::size_t n_values,
12181223
InputT *input_arr,
12191224
OutputT *output_arr,
12201225
const ProjT &proj_op,
@@ -1228,7 +1233,7 @@ struct subgroup_radix_sort
12281233
assert(n_values <= static_cast<std::size_t>(block_size) *
12291234
static_cast<std::size_t>(wg_size));
12301235

1231-
uint16_t n = static_cast<uint16_t>(n_values);
1236+
const std::uint16_t n = static_cast<std::uint16_t>(n_values);
12321237
static_assert(std::is_same_v<std::remove_cv_t<InputT>, OutputT>);
12331238

12341239
using ValueT = OutputT;
@@ -1237,17 +1242,18 @@ struct subgroup_radix_sort
12371242

12381243
TempBuf<ValueT, SLM_value_tag> buf_val(
12391244
n_batch_size, static_cast<std::size_t>(block_size * wg_size));
1240-
TempBuf<std::uint32_t, SLM_counter_tag> buf_count(
1245+
TempBuf<std::uint16_t, SLM_counter_tag> buf_count(
12411246
n_batch_size, static_cast<std::size_t>(counter_buf_sz));
12421247

12431248
sycl::range<1> lRange{wg_size};
12441249

12451250
sycl::event sort_ev;
1246-
std::vector<sycl::event> deps = depends;
1251+
std::vector<sycl::event> deps{depends};
12471252

1248-
std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size;
1253+
const std::size_t n_batches =
1254+
(n_iters + n_batch_size - 1) / n_batch_size;
12491255

1250-
for (size_t batch_id = 0; batch_id < n_batches; ++batch_id) {
1256+
for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) {
12511257

12521258
const std::size_t block_start = batch_id * n_batch_size;
12531259

@@ -1286,46 +1292,49 @@ struct subgroup_radix_sort
12861292
const std::size_t iter_exchange_offset =
12871293
iter_id * exchange_acc_iter_stride;
12881294

1289-
uint16_t wi = ndit.get_local_linear_id();
1290-
uint16_t begin_bit = 0;
1295+
std::uint16_t wi = ndit.get_local_linear_id();
1296+
std::uint16_t begin_bit = 0;
12911297

1292-
constexpr uint16_t end_bit =
1298+
constexpr std::uint16_t end_bit =
12931299
number_of_bits_in_type<KeyT>();
12941300

1295-
// copy from input array into values
1301+
// copy from input array into values
12961302
#pragma unroll
1297-
for (uint16_t i = 0; i < block_size; ++i) {
1298-
const uint16_t id = wi * block_size + i;
1299-
if (id < n)
1300-
values[i] = std::move(
1301-
this_input_arr[iter_val_offset + id]);
1303+
for (std::uint16_t i = 0; i < block_size; ++i) {
1304+
const std::uint16_t id = wi * block_size + i;
1305+
values[i] =
1306+
(id < n) ? this_input_arr[iter_val_offset + id]
1307+
: ValueT{};
13021308
}
13031309

13041310
while (true) {
13051311
// indices for indirect access in the "re-order"
13061312
// phase
1307-
uint16_t indices[block_size];
1313+
std::uint16_t indices[block_size];
13081314
{
13091315
// pointers to bucket's counters
1310-
uint32_t *counters[block_size];
1316+
std::uint16_t *counters[block_size];
13111317

13121318
// counting phase
13131319
auto pcounter =
13141320
get_accessor_pointer(counter_acc) +
13151321
(wi + iter_counter_offset);
13161322

1317-
// initialize counters
1323+
// initialize counters
13181324
#pragma unroll
1319-
for (uint16_t i = 0; i < bin_count; ++i)
1320-
pcounter[i * wg_size] = std::uint32_t{0};
1325+
for (std::uint16_t i = 0; i < bin_count; ++i)
1326+
pcounter[i * wg_size] = std::uint16_t{0};
13211327

13221328
sycl::group_barrier(ndit.get_group());
13231329

13241330
if (is_ascending) {
13251331
#pragma unroll
1326-
for (uint16_t i = 0; i < block_size; ++i) {
1327-
const uint16_t id = wi * block_size + i;
1328-
constexpr uint16_t bin_mask =
1332+
for (std::uint16_t i = 0; i < block_size;
1333+
++i)
1334+
{
1335+
const std::uint16_t id =
1336+
wi * block_size + i;
1337+
constexpr std::uint16_t bin_mask =
13291338
bin_count - 1;
13301339

13311340
// points to the padded element, i.e. id
@@ -1334,7 +1343,7 @@ struct subgroup_radix_sort
13341343
default_out_of_range_bin_id =
13351344
bin_mask;
13361345

1337-
const uint16_t bin =
1346+
const std::uint16_t bin =
13381347
(id < n)
13391348
? get_bucket_id<bin_mask>(
13401349
order_preserving_cast<
@@ -1352,9 +1361,12 @@ struct subgroup_radix_sort
13521361
}
13531362
else {
13541363
#pragma unroll
1355-
for (uint16_t i = 0; i < block_size; ++i) {
1356-
const uint16_t id = wi * block_size + i;
1357-
constexpr uint16_t bin_mask =
1364+
for (std::uint16_t i = 0; i < block_size;
1365+
++i)
1366+
{
1367+
const std::uint16_t id =
1368+
wi * block_size + i;
1369+
constexpr std::uint16_t bin_mask =
13581370
bin_count - 1;
13591371

13601372
// points to the padded element, i.e. id
@@ -1363,7 +1375,7 @@ struct subgroup_radix_sort
13631375
default_out_of_range_bin_id =
13641376
bin_mask;
13651377

1366-
const uint16_t bin =
1378+
const std::uint16_t bin =
13671379
(id < n)
13681380
? get_bucket_id<bin_mask>(
13691381
order_preserving_cast<
@@ -1386,29 +1398,31 @@ struct subgroup_radix_sort
13861398
{
13871399

13881400
// scan contiguous numbers
1389-
uint16_t bin_sum[bin_count];
1401+
std::uint16_t bin_sum[bin_count];
13901402
const std::size_t counter_offset0 =
13911403
iter_counter_offset + wi * bin_count;
13921404
bin_sum[0] = counter_acc[counter_offset0];
13931405

13941406
#pragma unroll
1395-
for (uint16_t i = 1; i < bin_count; ++i)
1407+
for (std::uint16_t i = 1; i < bin_count;
1408+
++i)
13961409
bin_sum[i] =
13971410
bin_sum[i - 1] +
13981411
counter_acc[counter_offset0 + i];
13991412

14001413
sycl::group_barrier(ndit.get_group());
14011414

14021415
// exclusive scan local sum
1403-
uint16_t sum_scan =
1416+
std::uint16_t sum_scan =
14041417
sycl::exclusive_scan_over_group(
14051418
ndit.get_group(),
14061419
bin_sum[bin_count - 1],
1407-
sycl::plus<uint16_t>());
1420+
sycl::plus<std::uint16_t>());
14081421

14091422
// add to local sum, generate exclusive scan result
14101423
#pragma unroll
1411-
for (uint16_t i = 0; i < bin_count; ++i)
1424+
for (std::uint16_t i = 0; i < bin_count;
1425+
++i)
14121426
counter_acc[counter_offset0 + i + 1] =
14131427
sum_scan + bin_sum[i];
14141428

@@ -1420,51 +1434,50 @@ struct subgroup_radix_sort
14201434
}
14211435

14221436
#pragma unroll
1423-
for (uint16_t i = 0; i < block_size; ++i) {
1437+
for (std::uint16_t i = 0; i < block_size; ++i) {
14241438
// a global index is a local offset plus a
14251439
// global base index
14261440
indices[i] += *counters[i];
14271441
}
1442+
1443+
sycl::group_barrier(ndit.get_group());
14281444
}
14291445

14301446
begin_bit += radix;
14311447

14321448
// "re-order" phase
14331449
sycl::group_barrier(ndit.get_group());
14341450
if (begin_bit >= end_bit) {
1435-
// the last iteration - writing out the result
1451+
// the last iteration - writing out the result
14361452
#pragma unroll
1437-
for (uint16_t i = 0; i < block_size; ++i) {
1438-
const uint16_t r = indices[i];
1453+
for (std::uint16_t i = 0; i < block_size; ++i) {
1454+
const std::uint16_t r = indices[i];
14391455
if (r < n) {
1440-
// move the values to source range and
1441-
// destroy the values
14421456
this_output_arr[iter_val_offset + r] =
1443-
std::move(values[i]);
1457+
values[i];
14441458
}
14451459
}
14461460

14471461
return;
14481462
}
14491463

1450-
// data exchange
1464+
// data exchange
14511465
#pragma unroll
1452-
for (uint16_t i = 0; i < block_size; ++i) {
1453-
const uint16_t r = indices[i];
1466+
for (std::uint16_t i = 0; i < block_size; ++i) {
1467+
const std::uint16_t r = indices[i];
14541468
if (r < n)
14551469
exchange_acc[iter_exchange_offset + r] =
1456-
std::move(values[i]);
1470+
values[i];
14571471
}
14581472

14591473
sycl::group_barrier(ndit.get_group());
14601474

14611475
#pragma unroll
1462-
for (uint16_t i = 0; i < block_size; ++i) {
1463-
const uint16_t id = wi * block_size + i;
1476+
for (std::uint16_t i = 0; i < block_size; ++i) {
1477+
const std::uint16_t id = wi * block_size + i;
14641478
if (id < n)
1465-
values[i] = std::move(
1466-
exchange_acc[iter_exchange_offset +
1467-
id]);
1479+
values[i] =
1480+
exchange_acc[iter_exchange_offset + id];
14681481
}
14691482

14701483
sycl::group_barrier(ndit.get_group());
@@ -1736,10 +1749,10 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
17361749
const bool sort_ascending,
17371750
// number of sub-arrays to sort (num. of rows in a
17381751
// matrix when sorting over rows)
1739-
size_t iter_nelems,
1752+
std::size_t iter_nelems,
17401753
// size of each array to sort (length of rows,
17411754
// i.e. number of columns)
1742-
size_t sort_nelems,
1755+
std::size_t sort_nelems,
17431756
const char *arg_cp,
17441757
char *res_cp,
17451758
ssize_t iter_arg_offset,
@@ -1775,10 +1788,10 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17751788
const bool sort_ascending,
17761789
// number of sub-arrays to sort (num. of
17771790
// rows in a matrix when sorting over rows)
1778-
size_t iter_nelems,
1791+
std::size_t iter_nelems,
17791792
// size of each array to sort (length of
17801793
// rows, i.e. number of columns)
1781-
size_t sort_nelems,
1794+
std::size_t sort_nelems,
17821795
const char *arg_cp,
17831796
char *res_cp,
17841797
ssize_t iter_arg_offset,

0 commit comments

Comments
 (0)