Skip to content

Commit a56e21c

Browse files
committed
Add implementation of top_k using radix sort
1 parent a338383 commit a56e21c

File tree

2 files changed

+193
-63
lines changed

2 files changed

+193
-63
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 152 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
#include "kernels/dpctl_tensor_types.hpp"
3737
#include "merge_sort.hpp"
38+
#include "radix_sort.hpp"
3839
#include "utils/sycl_alloc_utils.hpp"
3940
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4041

@@ -70,31 +71,25 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
7071
} // namespace topk_detail
7172

7273
template <typename T1, typename T2, typename T3>
73-
class populate_index_data_full_sort_krn;
74+
class topk_populate_index_data_krn;
7475

7576
template <typename T1, typename T2, typename T3>
76-
class topk_map_to_rows_full_sort_krn;
77-
78-
template <typename T1, typename T2, typename T3> class populate_index_data_krn;
79-
80-
template <typename T1, typename T2, typename T3> class topk_map_to_rows_krn;
77+
class topk_full_merge_map_back_krn;
8178

8279
template <typename argTy, typename IndexTy, typename CompT>
83-
sycl::event topk_full_sort_impl(
84-
sycl::queue &exec_q,
85-
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
86-
// matrix when sorting over rows)
87-
std::size_t sort_nelems, // size of each array to sort (length of rows,
88-
// i.e. number of columns)
89-
std::size_t k,
90-
const argTy *arg_tp,
91-
argTy *vals_tp,
92-
IndexTy *inds_tp,
93-
const CompT &comp,
94-
const std::vector<sycl::event> &depends)
80+
sycl::event
81+
topk_full_merge_sort_impl(sycl::queue &exec_q,
82+
std::size_t iter_nelems, // number of sub-arrays
83+
std::size_t axis_nelems, // size of each sub-array
84+
std::size_t k,
85+
const argTy *arg_tp,
86+
argTy *vals_tp,
87+
IndexTy *inds_tp,
88+
const CompT &comp,
89+
const std::vector<sycl::event> &depends)
9590
{
9691
IndexTy *index_data =
97-
sycl::malloc_device<IndexTy>(iter_nelems * sort_nelems, exec_q);
92+
sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems, exec_q);
9893
if (index_data == nullptr) {
9994
throw std::runtime_error("Unable to allocate device_memory");
10095
}
@@ -103,10 +98,10 @@ sycl::event topk_full_sort_impl(
10398
exec_q.submit([&](sycl::handler &cgh) {
10499
cgh.depends_on(depends);
105100

106-
auto const &range = sycl::range<1>(iter_nelems * sort_nelems);
101+
auto const &range = sycl::range<1>(iter_nelems * axis_nelems);
107102

108103
using KernelName =
109-
populate_index_data_full_sort_krn<argTy, IndexTy, CompT>;
104+
topk_populate_index_data_krn<argTy, IndexTy, CompT>;
110105

111106
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
112107
std::size_t i = id[0];
@@ -118,34 +113,33 @@ sycl::event topk_full_sort_impl(
118113
// Sort segments of the array
119114
sycl::event base_sort_ev =
120115
merge_sort_detail::sort_over_work_group_contig_impl(
121-
exec_q, iter_nelems, sort_nelems, index_data, index_data, comp,
116+
exec_q, iter_nelems, axis_nelems, index_data, index_data, comp,
122117
sorted_block_size, // modified in place with size of sorted block
123118
// size
124119
{populate_indexed_data_ev});
125120

126121
// Merge segments in parallel until all elements are sorted
127122
sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
128-
exec_q, iter_nelems, sort_nelems, index_data, comp, sorted_block_size,
123+
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
129124
{base_sort_ev});
130125

131126
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
132127
cgh.depends_on(merges_ev);
133128

134-
using KernelName =
135-
topk_map_to_rows_full_sort_krn<argTy, IndexTy, CompT>;
129+
using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
136130

137131
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
138132
std::size_t gid = id[0];
139133

140134
std::size_t iter_gid = gid / k;
141135
std::size_t axis_gid = gid - (iter_gid * k);
142136

143-
std::size_t src_idx = iter_gid * sort_nelems + axis_gid;
137+
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
144138
std::size_t dst_idx = iter_gid * k + axis_gid;
145139

146140
auto res_ind = index_data[src_idx];
147141
vals_tp[dst_idx] = arg_tp[res_ind];
148-
inds_tp[dst_idx] = res_ind % sort_nelems;
142+
inds_tp[dst_idx] = res_ind % axis_nelems;
149143
});
150144
});
151145

@@ -162,29 +156,32 @@ sycl::event topk_full_sort_impl(
162156
return cleanup_host_task_event;
163157
};
164158

159+
template <typename T1, typename T2, typename T3>
160+
class topk_partial_merge_map_back_krn;
161+
165162
template <typename T1, typename T2, typename Comp>
166163
class topk_over_work_group_krn;
167164

168165
template <typename argTy,
169166
typename IndexTy,
170167
typename ValueComp = std::less<argTy>>
171-
sycl::event
172-
topk_impl(sycl::queue &exec_q,
173-
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
174-
// in a matrix when sorting over rows)
175-
std::size_t axis_nelems, // size of each array to sort (length of
176-
// rows, i.e. number of columns)
177-
std::size_t k,
178-
const char *arg_cp,
179-
char *vals_cp,
180-
char *inds_cp,
181-
dpctl::tensor::ssize_t iter_arg_offset,
182-
dpctl::tensor::ssize_t iter_vals_offset,
183-
dpctl::tensor::ssize_t iter_inds_offset,
184-
dpctl::tensor::ssize_t axis_arg_offset,
185-
dpctl::tensor::ssize_t axis_vals_offset,
186-
dpctl::tensor::ssize_t axis_inds_offset,
187-
const std::vector<sycl::event> &depends)
168+
sycl::event topk_merge_impl(
169+
sycl::queue &exec_q,
170+
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
171+
// in a matrix when sorting over rows)
172+
std::size_t axis_nelems, // size of each array to sort (length of
173+
// rows, i.e. number of columns)
174+
std::size_t k,
175+
const char *arg_cp,
176+
char *vals_cp,
177+
char *inds_cp,
178+
dpctl::tensor::ssize_t iter_arg_offset,
179+
dpctl::tensor::ssize_t iter_vals_offset,
180+
dpctl::tensor::ssize_t iter_inds_offset,
181+
dpctl::tensor::ssize_t axis_arg_offset,
182+
dpctl::tensor::ssize_t axis_vals_offset,
183+
dpctl::tensor::ssize_t axis_inds_offset,
184+
const std::vector<sycl::event> &depends)
188185
{
189186
if (axis_nelems < k) {
190187
throw std::runtime_error("Invalid sort axis size for value of k");
@@ -201,8 +198,9 @@ topk_impl(sycl::queue &exec_q,
201198
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
202199

203200
if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) {
204-
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k, arg_tp,
205-
vals_tp, inds_tp, index_comp, depends);
201+
return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k,
202+
arg_tp, vals_tp, inds_tp, index_comp,
203+
depends);
206204
}
207205
else {
208206
using PartialKernelName =
@@ -269,9 +267,9 @@ topk_impl(sycl::queue &exec_q,
269267
if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
270268
alloc_len >= axis_nelems / 2)
271269
{
272-
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k,
273-
arg_tp, vals_tp, inds_tp, index_comp,
274-
depends);
270+
return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems,
271+
k, arg_tp, vals_tp, inds_tp,
272+
index_comp, depends);
275273
}
276274

277275
IndexTy *index_data =
@@ -399,7 +397,8 @@ topk_impl(sycl::queue &exec_q,
399397
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
400398
cgh.depends_on(merges_ev);
401399

402-
using KernelName = topk_map_to_rows_krn<argTy, IndexTy, ValueComp>;
400+
using KernelName =
401+
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
403402

404403
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
405404
std::size_t gid = id[0];
@@ -430,6 +429,109 @@ topk_impl(sycl::queue &exec_q,
430429
}
431430
}
432431

432+
template <typename T1, typename T2> class topk_iota_krn;
433+
434+
template <typename T1, typename T2> class topk_radix_map_back_krn;
435+
436+
template <typename argTy, typename IndexTy>
437+
sycl::event topk_radix_impl(sycl::queue &exec_q,
438+
std::size_t iter_nelems, // number of sub-arrays
439+
std::size_t axis_nelems, // size of each sub-array
440+
std::size_t k,
441+
bool ascending,
442+
const char *arg_cp,
443+
char *vals_cp,
444+
char *inds_cp,
445+
dpctl::tensor::ssize_t iter_arg_offset,
446+
dpctl::tensor::ssize_t iter_vals_offset,
447+
dpctl::tensor::ssize_t iter_inds_offset,
448+
dpctl::tensor::ssize_t axis_arg_offset,
449+
dpctl::tensor::ssize_t axis_vals_offset,
450+
dpctl::tensor::ssize_t axis_inds_offset,
451+
const std::vector<sycl::event> &depends)
452+
{
453+
if (axis_nelems < k) {
454+
throw std::runtime_error("Invalid sort axis size for value of k");
455+
}
456+
457+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
458+
iter_arg_offset + axis_arg_offset;
459+
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
460+
axis_vals_offset;
461+
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
462+
axis_inds_offset;
463+
464+
const std::size_t total_nelems = iter_nelems * axis_nelems;
465+
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
466+
IndexTy *workspace = sycl::malloc_device<IndexTy>(
467+
padded_total_nelems + total_nelems, exec_q);
468+
469+
IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
470+
471+
if (nullptr == workspace || nullptr == tmp_tp) {
472+
throw std::runtime_error(
473+
"Not enough device memory for radix sort topk");
474+
}
475+
476+
using IdentityProjT = radix_sort_details::IdentityProj;
477+
using IndexedProjT =
478+
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
479+
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
480+
481+
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
482+
cgh.depends_on(depends);
483+
484+
using KernelName = topk_iota_krn<argTy, IndexTy>;
485+
486+
cgh.parallel_for<KernelName>(
487+
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
488+
size_t i = id[0];
489+
IndexTy sort_id = static_cast<IndexTy>(i);
490+
workspace[i] = sort_id;
491+
});
492+
});
493+
494+
sycl::event radix_sort_ev =
495+
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
496+
exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
497+
ascending, {iota_ev});
498+
499+
// Write out top k of the temporary
500+
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
501+
cgh.depends_on(radix_sort_ev);
502+
503+
using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
504+
505+
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
506+
std::size_t gid = id[0];
507+
508+
std::size_t iter_gid = gid / k;
509+
std::size_t axis_gid = gid - (iter_gid * k);
510+
511+
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
512+
std::size_t dst_idx = iter_gid * k + axis_gid;
513+
514+
IndexTy res_ind = tmp_tp[src_idx];
515+
vals_tp[dst_idx] = arg_tp[res_ind];
516+
inds_tp[dst_idx] = res_ind % axis_nelems;
517+
});
518+
});
519+
520+
sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
521+
cgh.depends_on(write_topk_ev);
522+
523+
const sycl::context &ctx = exec_q.get_context();
524+
525+
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
526+
cgh.host_task([ctx, workspace, tmp_tp] {
527+
sycl_free_noexcept(workspace, ctx);
528+
sycl_free_noexcept(tmp_tp, ctx);
529+
});
530+
});
531+
532+
return cleanup_ev;
533+
}
534+
433535
} // end of namespace kernels
434536
} // end of namespace tensor
435537
} // end of namespace dpctl

dpctl/tensor/libtensor/source/sorting/topk.cpp

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,23 @@ static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types];
7676
namespace
7777
{
7878

79+
template <typename T, typename = void>
80+
struct use_radix_sort : public std::false_type
81+
{
82+
};
83+
84+
template <typename T>
85+
struct use_radix_sort<
86+
T,
87+
std::enable_if_t<std::disjunction<std::is_same<T, bool>,
88+
std::is_same<T, std::uint8_t>,
89+
std::is_same<T, std::int8_t>,
90+
std::is_same<T, std::uint16_t>,
91+
std::is_same<T, std::int16_t>>::value>>
92+
: public std::true_type
93+
{
94+
};
95+
7996
template <typename argTy, typename IndexTy>
8097
sycl::event
8198
topk_caller(sycl::queue &exec_q,
@@ -96,22 +113,33 @@ topk_caller(sycl::queue &exec_q,
96113
py::ssize_t axis_inds_offset,
97114
const std::vector<sycl::event> &depends)
98115
{
99-
using dpctl::tensor::kernels::topk_impl;
100-
if (largest) {
101-
using CompTy =
102-
typename dpctl::tensor::py_internal::DescendingSorter<argTy>::type;
103-
return topk_impl<argTy, IndexTy, CompTy>(
104-
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
105-
iter_arg_offset, iter_vals_offset, iter_inds_offset,
116+
if constexpr (use_radix_sort<argTy>::value) {
117+
using dpctl::tensor::kernels::topk_radix_impl;
118+
auto ascending = !largest;
119+
return topk_radix_impl<argTy, IndexTy>(
120+
exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp,
121+
inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset,
106122
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
107123
}
108124
else {
109-
using CompTy =
110-
typename dpctl::tensor::py_internal::AscendingSorter<argTy>::type;
111-
return topk_impl<argTy, IndexTy, CompTy>(
112-
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
113-
iter_arg_offset, iter_vals_offset, iter_inds_offset,
114-
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
125+
using dpctl::tensor::kernels::topk_merge_impl;
126+
if (largest) {
127+
using CompTy =
128+
typename dpctl::tensor::py_internal::DescendingSorter<
129+
argTy>::type;
130+
return topk_merge_impl<argTy, IndexTy, CompTy>(
131+
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
132+
iter_arg_offset, iter_vals_offset, iter_inds_offset,
133+
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
134+
}
135+
else {
136+
using CompTy = typename dpctl::tensor::py_internal::AscendingSorter<
137+
argTy>::type;
138+
return topk_merge_impl<argTy, IndexTy, CompTy>(
139+
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
140+
iter_arg_offset, iter_vals_offset, iter_inds_offset,
141+
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
142+
}
115143
}
116144
}
117145

0 commit comments

Comments
 (0)