Skip to content

Technical debt changes in radix_sort.hpp #1918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#pragma once

#include <array>
#include <cstdint>
#include <limits>
#include <stdexcept>
Expand Down Expand Up @@ -477,10 +478,10 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
sycl::access::address_space::local_space>;
using TempStorageT = sycl::local_accessor<std::uint32_t, 1>;

sycl::sub_group sgroup;
std::uint32_t lid;
std::uint32_t item_mask;
AtomicT atomic_peer_mask;
const sycl::sub_group sgroup;
const std::uint32_t lid;
const std::uint32_t item_mask;
const AtomicT atomic_peer_mask;

peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT lacc)
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
Expand All @@ -490,7 +491,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>

std::uint32_t peer_contribution(OffsetT &new_offset_id,
OffsetT offset_prefix,
bool wi_bit_set)
bool wi_bit_set) const
{
// reset mask for each radix state
if (lid == 0)
Expand Down Expand Up @@ -523,8 +524,8 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
using ItemType = sycl::nd_item<1>;
using SubGroupType = sycl::sub_group;

SubGroupType sgroup;
std::uint32_t sg_size;
const SubGroupType sgroup;
const std::uint32_t sg_size;

peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT)
: sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0])
Expand All @@ -533,7 +534,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>

std::uint32_t peer_contribution(OffsetT &new_offset_id,
OffsetT offset_prefix,
bool wi_bit_set)
bool wi_bit_set) const
{
const std::uint32_t contrib{wi_bit_set ? std::uint32_t{1}
: std::uint32_t{0}};
Expand Down Expand Up @@ -567,9 +568,9 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
public:
using TempStorageT = empty_storage;

sycl::sub_group sgroup;
std::uint32_t lid;
sycl::ext::oneapi::sub_group_mask item_sg_mask;
const sycl::sub_group sgroup;
const std::uint32_t lid;
const sycl::ext::oneapi::sub_group_mask item_sg_mask;

peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT)
: sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
Expand All @@ -580,7 +581,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>

std::uint32_t peer_contribution(OffsetT &new_offset_id,
OffsetT offset_prefix,
bool wi_bit_set)
bool wi_bit_set) const
{
// set local id's bit to 1 if the bucket value matches the radix state
auto peer_mask = sycl::ext::oneapi::group_ballot(sgroup, wi_bit_set);
Expand Down Expand Up @@ -750,7 +751,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
const std::uint32_t tail_size = (seg_end - seg_start) % sg_size;
seg_end -= tail_size;

PeerHelper peer_prefix_hlp(ndit, peer_temp);
const PeerHelper peer_prefix_hlp(ndit, peer_temp);

// find offsets for the same values within a segment and fill the
// resulting buffer
Expand Down Expand Up @@ -967,8 +968,13 @@ struct parallel_radix_sort_iteration_step

// 3. Reorder Phase
sycl::event reorder_ev{};
if (reorder_sg_size == 8 || reorder_sg_size == 16 ||
reorder_sg_size == 32)
// subgroup_ballot-based peer algo uses extract_bits to populate
// uint32_t mask and hence relies on the sub-group size being 32 or narrower
constexpr std::size_t sg32_v = 32u;
constexpr std::size_t sg16_v = 16u;
constexpr std::size_t sg08_v = 8u;
if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size ||
sg08_v == reorder_sg_size)
{
constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;

Expand Down
Loading