|
36 | 36 | #include "kernels/dpctl_tensor_types.hpp"
|
37 | 37 | #include "merge_sort.hpp"
|
38 | 38 | #include "radix_sort.hpp"
|
| 39 | +#include "search_sorted_detail.hpp" |
39 | 40 | #include "utils/sycl_alloc_utils.hpp"
|
40 | 41 | #include <sycl/ext/oneapi/sub_group_mask.hpp>
|
41 | 42 |
|
@@ -247,14 +248,17 @@ sycl::event topk_merge_impl(
|
247 | 248 | // This assumption permits doing away with using a loop
|
248 | 249 | assert(sorted_block_size % lws == 0);
|
249 | 250 |
|
| 251 | + using search_sorted_detail::quotient_ceil; |
250 | 252 | const std::size_t n_segments =
|
251 |
| - merge_sort_detail::quotient_ceil<std::size_t>(axis_nelems, |
252 |
| - sorted_block_size); |
| 253 | + quotient_ceil<std::size_t>(axis_nelems, sorted_block_size); |
253 | 254 |
|
254 |
| - // round k up for the later merge kernel |
| 255 | + // round k up for the later merge kernel if necessary |
| 256 | + const std::size_t round_k_to = |
| 257 | + merge_sort_detail::get_merge_segment_size(dev); |
255 | 258 | std::size_t k_rounded =
|
256 |
| - merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) * |
257 |
| - elems_per_wi; |
| 259 | + (k < round_k_to) |
| 260 | + ? k |
| 261 | + : quotient_ceil<std::size_t>(k, round_k_to) * round_k_to; |
258 | 262 |
|
259 | 263 | // get length of tail for alloc size
|
260 | 264 | auto rem = axis_nelems % sorted_block_size;
|
@@ -322,8 +326,7 @@ sycl::event topk_merge_impl(
|
322 | 326 | sycl::group_barrier(it.get_group());
|
323 | 327 |
|
324 | 328 | const std::size_t chunk =
|
325 |
| - merge_sort_detail::quotient_ceil<std::size_t>( |
326 |
| - sorted_block_size, lws); |
| 329 | + quotient_ceil<std::size_t>(sorted_block_size, lws); |
327 | 330 |
|
328 | 331 | const std::size_t chunk_start_idx = lid * chunk;
|
329 | 332 | const std::size_t chunk_end_idx =
|
|
0 commit comments