27
27
28
28
#pragma once
29
29
30
+ #include < array>
30
31
#include < cstdint>
31
32
#include < limits>
32
33
#include < stdexcept>
@@ -477,10 +478,10 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
477
478
sycl::access::address_space::local_space>;
478
479
using TempStorageT = sycl::local_accessor<std::uint32_t , 1 >;
479
480
480
- sycl::sub_group sgroup;
481
- std::uint32_t lid;
482
- std::uint32_t item_mask;
483
- AtomicT atomic_peer_mask;
481
+ const sycl::sub_group sgroup;
482
+ const std::uint32_t lid;
483
+ const std::uint32_t item_mask;
484
+ const AtomicT atomic_peer_mask;
484
485
485
486
peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT lacc)
486
487
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -490,7 +491,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
490
491
491
492
std::uint32_t peer_contribution (OffsetT &new_offset_id,
492
493
OffsetT offset_prefix,
493
- bool wi_bit_set)
494
+ bool wi_bit_set) const
494
495
{
495
496
// reset mask for each radix state
496
497
if (lid == 0 )
@@ -523,8 +524,8 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
523
524
using ItemType = sycl::nd_item<1 >;
524
525
using SubGroupType = sycl::sub_group;
525
526
526
- SubGroupType sgroup;
527
- std::uint32_t sg_size;
527
+ const SubGroupType sgroup;
528
+ const std::uint32_t sg_size;
528
529
529
530
peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT)
530
531
: sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0 ])
@@ -533,7 +534,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
533
534
534
535
std::uint32_t peer_contribution (OffsetT &new_offset_id,
535
536
OffsetT offset_prefix,
536
- bool wi_bit_set)
537
+ bool wi_bit_set) const
537
538
{
538
539
const std::uint32_t contrib{wi_bit_set ? std::uint32_t {1 }
539
540
: std::uint32_t {0 }};
@@ -567,9 +568,9 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
567
568
public:
568
569
using TempStorageT = empty_storage;
569
570
570
- sycl::sub_group sgroup;
571
- std::uint32_t lid;
572
- sycl::ext::oneapi::sub_group_mask item_sg_mask;
571
+ const sycl::sub_group sgroup;
572
+ const std::uint32_t lid;
573
+ const sycl::ext::oneapi::sub_group_mask item_sg_mask;
573
574
574
575
peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT)
575
576
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -580,7 +581,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
580
581
581
582
std::uint32_t peer_contribution (OffsetT &new_offset_id,
582
583
OffsetT offset_prefix,
583
- bool wi_bit_set)
584
+ bool wi_bit_set) const
584
585
{
585
586
// set local id's bit to 1 if the bucket value matches the radix state
586
587
auto peer_mask = sycl::ext::oneapi::group_ballot (sgroup, wi_bit_set);
@@ -750,7 +751,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
750
751
const std::uint32_t tail_size = (seg_end - seg_start) % sg_size;
751
752
seg_end -= tail_size;
752
753
753
- PeerHelper peer_prefix_hlp (ndit, peer_temp);
754
+ const PeerHelper peer_prefix_hlp (ndit, peer_temp);
754
755
755
756
// find offsets for the same values within a segment and fill the
756
757
// resulting buffer
@@ -967,8 +968,13 @@ struct parallel_radix_sort_iteration_step
967
968
968
969
// 3. Reorder Phase
969
970
sycl::event reorder_ev{};
970
- if (reorder_sg_size == 8 || reorder_sg_size == 16 ||
971
- reorder_sg_size == 32 )
971
+ // subgroup_ballot-based peer algo uses extract_bits to populate
972
+ // uint32_t mask and hence relies on sub-group to be 32 or narrower
973
+ constexpr std::size_t sg32_v = 32u ;
974
+ constexpr std::size_t sg16_v = 16u ;
975
+ constexpr std::size_t sg08_v = 8u ;
976
+ if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size ||
977
+ sg08_v == reorder_sg_size)
972
978
{
973
979
constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;
974
980
0 commit comments