@@ -86,7 +86,7 @@ sycl::event topk_full_sort_impl(
86
86
// matrix when sorting over rows)
87
87
std::size_t sort_nelems, // size of each array to sort (length of rows,
88
88
// i.e. number of columns)
89
- dpctl::tensor:: ssize_t k,
89
+ std:: size_t k,
90
90
const argTy *arg_tp,
91
91
argTy *vals_tp,
92
92
IndexTy *inds_tp,
@@ -174,7 +174,7 @@ topk_impl(sycl::queue &exec_q,
174
174
// in a matrix when sorting over rows)
175
175
std::size_t axis_nelems, // size of each array to sort (length of
176
176
// rows, i.e. number of columns)
177
- dpctl::tensor:: ssize_t k,
177
+ std:: size_t k,
178
178
const char *arg_cp,
179
179
char *vals_cp,
180
180
char *inds_cp,
@@ -186,7 +186,7 @@ topk_impl(sycl::queue &exec_q,
186
186
dpctl::tensor::ssize_t axis_inds_offset,
187
187
const std::vector<sycl::event> &depends)
188
188
{
189
- if (axis_nelems < static_cast <std:: size_t >(k) ) {
189
+ if (axis_nelems < k ) {
190
190
throw std::runtime_error (" Invalid sort axis size for value of k" );
191
191
}
192
192
@@ -200,9 +200,7 @@ topk_impl(sycl::queue &exec_q,
200
200
using dpctl::tensor::kernels::IndexComp;
201
201
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
202
202
203
- if (axis_nelems <= 512 || k >= 1024 ||
204
- static_cast <std::size_t >(k) > axis_nelems / 2 )
205
- {
203
+ if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2 ) {
206
204
return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k, arg_tp,
207
205
vals_tp, inds_tp, index_comp, depends);
208
206
}
@@ -256,22 +254,19 @@ topk_impl(sycl::queue &exec_q,
256
254
sorted_block_size);
257
255
258
256
// round k up for the later merge kernel
259
- const dpctl::tensor::ssize_t round_k_to = elems_per_wi;
260
- dpctl::tensor::ssize_t k_rounded =
261
- merge_sort_detail::quotient_ceil<dpctl::tensor::ssize_t >(
262
- k, round_k_to) *
263
- round_k_to;
257
+ std::size_t k_rounded =
258
+ merge_sort_detail::quotient_ceil<std::size_t >(k, elems_per_wi) *
259
+ elems_per_wi;
264
260
265
261
// get length of tail for alloc size
266
262
auto rem = axis_nelems % sorted_block_size;
267
- auto alloc_len = (rem && rem < static_cast <std:: size_t >( k_rounded) )
263
+ auto alloc_len = (rem && rem < k_rounded)
268
264
? rem + k_rounded * (n_segments - 1 )
269
265
: k_rounded * n_segments;
270
266
271
267
// if allocation would be sufficiently large or k is larger than
272
268
// elements processed, use full sort
273
- if (static_cast <std::size_t >(k_rounded) >= axis_nelems ||
274
- static_cast <std::size_t >(k_rounded) >= sorted_block_size ||
269
+ if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
275
270
alloc_len >= axis_nelems / 2 )
276
271
{
277
272
return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k,
@@ -386,7 +381,7 @@ topk_impl(sycl::queue &exec_q,
386
381
for (std::size_t array_id = k_segment_start_idx + lid;
387
382
array_id < k_segment_end_idx; array_id += lws)
388
383
{
389
- if (lid < static_cast <std:: size_t >( k_rounded) ) {
384
+ if (lid < k_rounded) {
390
385
index_data[iter_id * alloc_len + array_id] =
391
386
out_src[array_id - k_segment_start_idx];
392
387
}
0 commit comments