Skip to content

Commit a338383

Browse files
committed
Use std::size_t for k instead of py::ssize_t
Reduces amount of casting. `k` will need to fit in `py::ssize_t` regardless.
1 parent 882c70d commit a338383

File tree

2 files changed

+19
-25
lines changed

2 files changed

+19
-25
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ sycl::event topk_full_sort_impl(
8686
// matrix when sorting over rows)
8787
std::size_t sort_nelems, // size of each array to sort (length of rows,
8888
// i.e. number of columns)
89-
dpctl::tensor::ssize_t k,
89+
std::size_t k,
9090
const argTy *arg_tp,
9191
argTy *vals_tp,
9292
IndexTy *inds_tp,
@@ -174,7 +174,7 @@ topk_impl(sycl::queue &exec_q,
174174
// in a matrix when sorting over rows)
175175
std::size_t axis_nelems, // size of each array to sort (length of
176176
// rows, i.e. number of columns)
177-
dpctl::tensor::ssize_t k,
177+
std::size_t k,
178178
const char *arg_cp,
179179
char *vals_cp,
180180
char *inds_cp,
@@ -186,7 +186,7 @@ topk_impl(sycl::queue &exec_q,
186186
dpctl::tensor::ssize_t axis_inds_offset,
187187
const std::vector<sycl::event> &depends)
188188
{
189-
if (axis_nelems < static_cast<std::size_t>(k)) {
189+
if (axis_nelems < k) {
190190
throw std::runtime_error("Invalid sort axis size for value of k");
191191
}
192192

@@ -200,9 +200,7 @@ topk_impl(sycl::queue &exec_q,
200200
using dpctl::tensor::kernels::IndexComp;
201201
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
202202

203-
if (axis_nelems <= 512 || k >= 1024 ||
204-
static_cast<std::size_t>(k) > axis_nelems / 2)
205-
{
203+
if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) {
206204
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k, arg_tp,
207205
vals_tp, inds_tp, index_comp, depends);
208206
}
@@ -256,22 +254,19 @@ topk_impl(sycl::queue &exec_q,
256254
sorted_block_size);
257255

258256
// round k up for the later merge kernel
259-
const dpctl::tensor::ssize_t round_k_to = elems_per_wi;
260-
dpctl::tensor::ssize_t k_rounded =
261-
merge_sort_detail::quotient_ceil<dpctl::tensor::ssize_t>(
262-
k, round_k_to) *
263-
round_k_to;
257+
std::size_t k_rounded =
258+
merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) *
259+
elems_per_wi;
264260

265261
// get length of tail for alloc size
266262
auto rem = axis_nelems % sorted_block_size;
267-
auto alloc_len = (rem && rem < static_cast<std::size_t>(k_rounded))
263+
auto alloc_len = (rem && rem < k_rounded)
268264
? rem + k_rounded * (n_segments - 1)
269265
: k_rounded * n_segments;
270266

271267
// if allocation would be sufficiently large or k is larger than
272268
// elements processed, use full sort
273-
if (static_cast<std::size_t>(k_rounded) >= axis_nelems ||
274-
static_cast<std::size_t>(k_rounded) >= sorted_block_size ||
269+
if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
275270
alloc_len >= axis_nelems / 2)
276271
{
277272
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k,
@@ -386,7 +381,7 @@ topk_impl(sycl::queue &exec_q,
386381
for (std::size_t array_id = k_segment_start_idx + lid;
387382
array_id < k_segment_end_idx; array_id += lws)
388383
{
389-
if (lid < static_cast<std::size_t>(k_rounded)) {
384+
if (lid < k_rounded) {
390385
index_data[iter_id * alloc_len + array_id] =
391386
out_src[array_id - k_segment_start_idx];
392387
}

dpctl/tensor/libtensor/source/sorting/topk.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ namespace py_internal
5858
typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &,
5959
std::size_t,
6060
std::size_t,
61-
py::ssize_t,
61+
std::size_t,
6262
bool,
6363
const char *,
6464
char *,
@@ -83,7 +83,7 @@ topk_caller(sycl::queue &exec_q,
8383
// rows in a matrix when sorting over rows)
8484
std::size_t axis_nelems, // size of each array to sort (length of
8585
// rows, i.e. number of columns)
86-
py::ssize_t k,
86+
std::size_t k,
8787
bool largest,
8888
const char *arg_cp,
8989
char *vals_cp,
@@ -120,7 +120,7 @@ topk_caller(sycl::queue &exec_q,
120120
std::pair<sycl::event, sycl::event>
121121
py_topk(const dpctl::tensor::usm_ndarray &src,
122122
const int trailing_dims_to_search,
123-
const py::ssize_t k,
123+
const std::size_t k,
124124
const bool largest,
125125
const dpctl::tensor::usm_ndarray &vals,
126126
const dpctl::tensor::usm_ndarray &inds,
@@ -168,9 +168,7 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
168168
inds_k *= static_cast<std::size_t>(inds_shape_ptr[i]);
169169
}
170170

171-
bool valid_k = (vals_k == static_cast<std::size_t>(k) &&
172-
inds_k == static_cast<std::size_t>(k) &&
173-
axis_nelems >= static_cast<std::size_t>(k));
171+
bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
174172
if (!valid_k) {
175173
throw py::value_error(
176174
"The value of k is invalid for the input and destination arrays");
@@ -243,7 +241,7 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
243241

244242
std::pair<sycl::event, sycl::event>
245243
py_topk(const dpctl::tensor::usm_ndarray &src,
246-
const py::ssize_t k,
244+
const std::size_t k,
247245
const bool largest,
248246
const dpctl::tensor::usm_ndarray &vals,
249247
const dpctl::tensor::usm_ndarray &inds,
@@ -266,8 +264,9 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
266264
axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
267265
}
268266

269-
bool valid_k = (axis_nelems >= static_cast<std::size_t>(k) &&
270-
vals_shape_ptr[0] == k && inds_shape_ptr[0] == k);
267+
bool valid_k =
268+
(axis_nelems >= k && static_cast<std::size_t>(vals_shape_ptr[0]) == k &&
269+
static_cast<std::size_t>(inds_shape_ptr[0]) == k);
271270
if (!valid_k) {
272271
throw py::value_error(
273272
"The value of k is invalid for the input and destination arrays");
@@ -360,7 +359,7 @@ void init_topk_functions(py::module_ m)
360359

361360
auto py_topk = [](const dpctl::tensor::usm_ndarray &src,
362361
std::optional<const int> trailing_dims_to_search,
363-
const py::ssize_t k, const bool largest,
362+
const std::size_t k, const bool largest,
364363
const dpctl::tensor::usm_ndarray &vals,
365364
const dpctl::tensor::usm_ndarray &inds,
366365
sycl::queue &exec_q,

0 commit comments

Comments
 (0)