Skip to content

Commit 295c532

Browse files
Merge pull request #1918 from IntelPython/radix-sort-technical-debt
Technical debt changes in radix_sort.hpp
2 parents 631d4c3 + 19a7e75 commit 295c532

File tree

1 file changed, with 21 additions and 15 deletions.

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727

2828
#pragma once
2929

30+
#include <array>
3031
#include <cstdint>
3132
#include <limits>
3233
#include <stdexcept>
@@ -477,10 +478,10 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
477478
sycl::access::address_space::local_space>;
478479
using TempStorageT = sycl::local_accessor<std::uint32_t, 1>;
479480

480-
sycl::sub_group sgroup;
481-
std::uint32_t lid;
482-
std::uint32_t item_mask;
483-
AtomicT atomic_peer_mask;
481+
const sycl::sub_group sgroup;
482+
const std::uint32_t lid;
483+
const std::uint32_t item_mask;
484+
const AtomicT atomic_peer_mask;
484485

485486
peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT lacc)
486487
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -490,7 +491,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
490491

491492
std::uint32_t peer_contribution(OffsetT &new_offset_id,
492493
OffsetT offset_prefix,
493-
bool wi_bit_set)
494+
bool wi_bit_set) const
494495
{
495496
// reset mask for each radix state
496497
if (lid == 0)
@@ -523,8 +524,8 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
523524
using ItemType = sycl::nd_item<1>;
524525
using SubGroupType = sycl::sub_group;
525526

526-
SubGroupType sgroup;
527-
std::uint32_t sg_size;
527+
const SubGroupType sgroup;
528+
const std::uint32_t sg_size;
528529

529530
peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT)
530531
: sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0])
@@ -533,7 +534,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
533534

534535
std::uint32_t peer_contribution(OffsetT &new_offset_id,
535536
OffsetT offset_prefix,
536-
bool wi_bit_set)
537+
bool wi_bit_set) const
537538
{
538539
const std::uint32_t contrib{wi_bit_set ? std::uint32_t{1}
539540
: std::uint32_t{0}};
@@ -567,9 +568,9 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
567568
public:
568569
using TempStorageT = empty_storage;
569570

570-
sycl::sub_group sgroup;
571-
std::uint32_t lid;
572-
sycl::ext::oneapi::sub_group_mask item_sg_mask;
571+
const sycl::sub_group sgroup;
572+
const std::uint32_t lid;
573+
const sycl::ext::oneapi::sub_group_mask item_sg_mask;
573574

574575
peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT)
575576
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -580,7 +581,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
580581

581582
std::uint32_t peer_contribution(OffsetT &new_offset_id,
582583
OffsetT offset_prefix,
583-
bool wi_bit_set)
584+
bool wi_bit_set) const
584585
{
585586
// set local id's bit to 1 if the bucket value matches the radix state
586587
auto peer_mask = sycl::ext::oneapi::group_ballot(sgroup, wi_bit_set);
@@ -750,7 +751,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
750751
const std::uint32_t tail_size = (seg_end - seg_start) % sg_size;
751752
seg_end -= tail_size;
752753

753-
PeerHelper peer_prefix_hlp(ndit, peer_temp);
754+
const PeerHelper peer_prefix_hlp(ndit, peer_temp);
754755

755756
// find offsets for the same values within a segment and fill the
756757
// resulting buffer
@@ -967,8 +968,13 @@ struct parallel_radix_sort_iteration_step
967968

968969
// 3. Reorder Phase
969970
sycl::event reorder_ev{};
970-
if (reorder_sg_size == 8 || reorder_sg_size == 16 ||
971-
reorder_sg_size == 32)
971+
// subgroup_ballot-based peer algo uses extract_bits to populate
972+
// uint32_t mask and hence requires the sub-group size to be 32 or smaller
973+
constexpr std::size_t sg32_v = 32u;
974+
constexpr std::size_t sg16_v = 16u;
975+
constexpr std::size_t sg08_v = 8u;
976+
if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size ||
977+
sg08_v == reorder_sg_size)
972978
{
973979
constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;
974980

0 commit comments

Comments
 (0)