Skip to content

Commit 8f38b80

Browse files
committed
Fix bug in top_k partial merge sort implementation
rounded value of k must be divisible by the merge sort chunk size
1 parent 5b0b80f commit 8f38b80

File tree

1 file changed

+9
-7
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+9
-7
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "kernels/dpctl_tensor_types.hpp"
3737
#include "merge_sort.hpp"
3838
#include "radix_sort.hpp"
39+
#include "search_sorted_detail.hpp"
3940
#include "utils/sycl_alloc_utils.hpp"
4041
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4142

@@ -247,14 +248,16 @@ sycl::event topk_merge_impl(
247248
// This assumption permits doing away with using a loop
248249
assert(sorted_block_size % lws == 0);
249250

251+
using search_sorted_detail::quotient_ceil;
250252
const std::size_t n_segments =
251-
merge_sort_detail::quotient_ceil<std::size_t>(axis_nelems,
252-
sorted_block_size);
253+
quotient_ceil<std::size_t>(axis_nelems, sorted_block_size);
253254

254-
// round k up for the later merge kernel
255+
// round k up for the later merge kernel if necessary
256+
const std::size_t round_k_to = dev.has(sycl::aspect::cpu) ? 32 : 4;
255257
std::size_t k_rounded =
256-
merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) *
257-
elems_per_wi;
258+
(k < round_k_to)
259+
? k
260+
: quotient_ceil<std::size_t>(k, round_k_to) * round_k_to;
258261

259262
// get length of tail for alloc size
260263
auto rem = axis_nelems % sorted_block_size;
@@ -322,8 +325,7 @@ sycl::event topk_merge_impl(
322325
sycl::group_barrier(it.get_group());
323326

324327
const std::size_t chunk =
325-
merge_sort_detail::quotient_ceil<std::size_t>(
326-
sorted_block_size, lws);
328+
quotient_ceil<std::size_t>(sorted_block_size, lws);
327329

328330
const std::size_t chunk_start_idx = lid * chunk;
329331
const std::size_t chunk_end_idx =

0 commit comments

Comments
 (0)