Skip to content

Commit 8bcb100

Browse files
committed
Fix bug in top_k partial merge sort implementation
rounded value of k must be divisible by the merge sort chunk size
1 parent 5b0b80f commit 8bcb100

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ namespace merge_sort_detail
4646

4747
using namespace dpctl::tensor::kernels::search_sorted_detail;
4848

49+
size_t get_merge_segment_size(const sycl::device &dev)
50+
{
51+
return dev.has(sycl::aspect::cpu) ? 32 : 4;
52+
}
53+
4954
/*! @brief Merge two contiguous sorted segments */
5055
template <typename InAcc, typename OutAcc, typename Compare>
5156
void merge_impl(const std::size_t offset,
@@ -580,7 +585,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
580585
// experimentally determined value
581586
// size of segments worked upon by each work-item during merging
582587
const sycl::device &dev = q.get_device();
583-
const size_t segment_size = (dev.has(sycl::aspect::cpu)) ? 32 : 4;
588+
const size_t segment_size = get_merge_segment_size(dev);
584589

585590
const size_t chunk_size =
586591
(sorted_block_size < segment_size) ? sorted_block_size : segment_size;

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "kernels/dpctl_tensor_types.hpp"
3737
#include "merge_sort.hpp"
3838
#include "radix_sort.hpp"
39+
#include "search_sorted_detail.hpp"
3940
#include "utils/sycl_alloc_utils.hpp"
4041
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4142

@@ -247,14 +248,17 @@ sycl::event topk_merge_impl(
247248
// This assumption permits doing away with using a loop
248249
assert(sorted_block_size % lws == 0);
249250

251+
using search_sorted_detail::quotient_ceil;
250252
const std::size_t n_segments =
251-
merge_sort_detail::quotient_ceil<std::size_t>(axis_nelems,
252-
sorted_block_size);
253+
quotient_ceil<std::size_t>(axis_nelems, sorted_block_size);
253254

254-
// round k up for the later merge kernel
255+
// round k up for the later merge kernel if necessary
256+
const std::size_t round_k_to =
257+
merge_sort_detail::get_merge_segment_size(dev);
255258
std::size_t k_rounded =
256-
merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) *
257-
elems_per_wi;
259+
(k < round_k_to)
260+
? k
261+
: quotient_ceil<std::size_t>(k, round_k_to) * round_k_to;
258262

259263
// get length of tail for alloc size
260264
auto rem = axis_nelems % sorted_block_size;
@@ -322,8 +326,7 @@ sycl::event topk_merge_impl(
322326
sycl::group_barrier(it.get_group());
323327

324328
const std::size_t chunk =
325-
merge_sort_detail::quotient_ceil<std::size_t>(
326-
sorted_block_size, lws);
329+
quotient_ceil<std::size_t>(sorted_block_size, lws);
327330

328331
const std::size_t chunk_start_idx = lid * chunk;
329332
const std::size_t chunk_end_idx =

0 commit comments

Comments
 (0)