Skip to content

Commit 869faef

Browse files
Reduced number of created iota and map_back kernels
Avoid using comparator type to form kernel name types for iota and map_back kernels (as they do not depedent on comparator). This reduces the number of kernels generated during instantiation of template implementation functions.
1 parent 87052e2 commit 869faef

File tree

4 files changed

+29
-70
lines changed

4 files changed

+29
-70
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ void merge_impl(const std::size_t offset,
211211
}
212212
}
213213

214-
namespace
215-
{
216214
template <typename Iter, typename Compare>
217215
void insertion_sort_impl(Iter first,
218216
const std::size_t begin,
@@ -259,7 +257,6 @@ void leaf_sort_impl(Iter first,
259257
return insertion_sort_impl<Iter, Compare>(
260258
std::move(first), std::move(begin), std::move(end), std::move(comp));
261259
}
262-
} // namespace
263260

264261
template <typename Iter> struct GetValueType
265262
{
@@ -768,9 +765,9 @@ sycl::event stable_sort_axis1_contig_impl(
768765
}
769766
}
770767

771-
template <typename T1, typename T2, typename T3> class populate_index_data_krn;
768+
template <typename T1, typename T2> class populate_index_data_krn;
772769

773-
template <typename T1, typename T2, typename T3> class index_map_to_rows_krn;
770+
template <typename T1, typename T2> class index_map_to_rows_krn;
774771

775772
template <typename IndexT, typename ValueT, typename ValueComp> struct IndexComp
776773
{
@@ -820,7 +817,7 @@ sycl::event stable_argsort_axis1_contig_impl(
820817

821818
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
822819

823-
using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
820+
using IotaKernelName = populate_index_data_krn<argTy, IndexTy>;
824821

825822
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
826823
exec_q, res_tp, total_nelems, depends);
@@ -838,7 +835,7 @@ sycl::event stable_argsort_axis1_contig_impl(
838835
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
839836
{base_sort_ev});
840837

841-
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
838+
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy>;
842839
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
843840

844841
sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,8 @@ template <typename ValueT, typename IndexT> struct ValueProj
17591759

17601760
template <typename IndexT, typename ValueT, typename ProjT> struct IndexedProj
17611761
{
1762+
IndexedProj(const ValueT *arg_ptr) : ptr(arg_ptr), value_projector{} {}
1763+
17621764
IndexedProj(const ValueT *arg_ptr, const ProjT &proj_op)
17631765
: ptr(arg_ptr), value_projector(proj_op)
17641766
{
@@ -1848,7 +1850,7 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
18481850
using IdentityProjT = radix_sort_details::IdentityProj;
18491851
using IndexedProjT =
18501852
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
1851-
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
1853+
const IndexedProjT proj_op{arg_tp};
18521854

18531855
using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;
18541856

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <stdexcept>
3333
#include <vector>
3434

35+
#include <sycl/ext/oneapi/sub_group_mask.hpp>
3536
#include <sycl/sycl.hpp>
3637

3738
#include "kernels/dpctl_tensor_types.hpp"
@@ -40,7 +41,6 @@
4041
#include "kernels/sorting/search_sorted_detail.hpp"
4142
#include "kernels/sorting/sort_utils.hpp"
4243
#include "utils/sycl_alloc_utils.hpp"
43-
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4444

4545
namespace dpctl
4646
{
@@ -134,11 +134,9 @@ sycl::event write_out_impl(sycl::queue &exec_q,
134134

135135
} // namespace topk_detail
136136

137-
template <typename T1, typename T2, typename T3>
138-
class topk_populate_index_data_krn;
137+
template <typename T1, typename T2> class topk_populate_index_data_krn;
139138

140-
template <typename T1, typename T2, typename T3>
141-
class topk_full_merge_map_back_krn;
139+
template <typename T1, typename T2> class topk_full_merge_map_back_krn;
142140

143141
template <typename argTy, typename IndexTy, typename CompT>
144142
sycl::event
@@ -158,7 +156,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
158156
// extract USM pointer
159157
IndexTy *index_data = index_data_owner.get();
160158

161-
using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy, CompT>;
159+
using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy>;
162160

163161
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
164162

@@ -179,8 +177,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
179177
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
180178
{base_sort_ev});
181179

182-
using WriteOutKernelName =
183-
topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
180+
using WriteOutKernelName = topk_full_merge_map_back_krn<argTy, IndexTy>;
184181

185182
sycl::event write_out_ev =
186183
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
@@ -194,8 +191,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
194191
return cleanup_host_task_event;
195192
};
196193

197-
template <typename T1, typename T2, typename T3>
198-
class topk_partial_merge_map_back_krn;
194+
template <typename T1, typename T2> class topk_partial_merge_map_back_krn;
199195

200196
template <typename T1, typename T2, typename Comp>
201197
class topk_over_work_group_krn;
@@ -213,24 +209,15 @@ sycl::event topk_merge_impl(
213209
const char *arg_cp,
214210
char *vals_cp,
215211
char *inds_cp,
216-
dpctl::tensor::ssize_t iter_arg_offset,
217-
dpctl::tensor::ssize_t iter_vals_offset,
218-
dpctl::tensor::ssize_t iter_inds_offset,
219-
dpctl::tensor::ssize_t axis_arg_offset,
220-
dpctl::tensor::ssize_t axis_vals_offset,
221-
dpctl::tensor::ssize_t axis_inds_offset,
222212
const std::vector<sycl::event> &depends)
223213
{
224214
if (axis_nelems < k) {
225215
throw std::runtime_error("Invalid sort axis size for value of k");
226216
}
227217

228-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
229-
iter_arg_offset + axis_arg_offset;
230-
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
231-
axis_vals_offset;
232-
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
233-
axis_inds_offset;
218+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
219+
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp);
220+
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp);
234221

235222
using dpctl::tensor::kernels::IndexComp;
236223
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
@@ -434,7 +421,7 @@ sycl::event topk_merge_impl(
434421

435422
// Write out top k of the merge-sorted memory
436423
using WriteOutKernelName =
437-
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
424+
topk_partial_merge_map_back_krn<argTy, IndexTy>;
438425

439426
sycl::event write_topk_ev =
440427
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
@@ -462,24 +449,15 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
462449
const char *arg_cp,
463450
char *vals_cp,
464451
char *inds_cp,
465-
dpctl::tensor::ssize_t iter_arg_offset,
466-
dpctl::tensor::ssize_t iter_vals_offset,
467-
dpctl::tensor::ssize_t iter_inds_offset,
468-
dpctl::tensor::ssize_t axis_arg_offset,
469-
dpctl::tensor::ssize_t axis_vals_offset,
470-
dpctl::tensor::ssize_t axis_inds_offset,
471452
const std::vector<sycl::event> &depends)
472453
{
473454
if (axis_nelems < k) {
474455
throw std::runtime_error("Invalid sort axis size for value of k");
475456
}
476457

477-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
478-
iter_arg_offset + axis_arg_offset;
479-
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
480-
axis_vals_offset;
481-
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
482-
axis_inds_offset;
458+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
459+
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp);
460+
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp);
483461

484462
const std::size_t total_nelems = iter_nelems * axis_nelems;
485463
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
@@ -494,7 +472,7 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
494472
using IdentityProjT = radix_sort_details::IdentityProj;
495473
using IndexedProjT =
496474
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
497-
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
475+
const IndexedProjT proj_op{arg_tp};
498476

499477
using IotaKernelName = topk_iota_krn<argTy, IndexTy>;
500478

dpctl/tensor/libtensor/source/sorting/topk.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@
4646
#include "rich_comparisons.hpp"
4747
#include "topk.hpp"
4848

49-
namespace td_ns = dpctl::tensor::type_dispatch;
50-
5149
namespace dpctl
5250
{
5351
namespace tensor
5452
{
5553
namespace py_internal
5654
{
5755

56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
5858
typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &,
5959
std::size_t,
6060
std::size_t,
@@ -63,12 +63,6 @@ typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &,
6363
const char *,
6464
char *,
6565
char *,
66-
py::ssize_t,
67-
py::ssize_t,
68-
py::ssize_t,
69-
py::ssize_t,
70-
py::ssize_t,
71-
py::ssize_t,
7266
const std::vector<sycl::event> &);
7367

7468
static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types];
@@ -102,21 +96,14 @@ sycl::event topk_caller(sycl::queue &exec_q,
10296
const char *arg_cp,
10397
char *vals_cp,
10498
char *inds_cp,
105-
py::ssize_t iter_arg_offset,
106-
py::ssize_t iter_vals_offset,
107-
py::ssize_t iter_inds_offset,
108-
py::ssize_t axis_arg_offset,
109-
py::ssize_t axis_vals_offset,
110-
py::ssize_t axis_inds_offset,
11199
const std::vector<sycl::event> &depends)
112100
{
113101
if constexpr (use_radix_sort<argTy>::value) {
114102
using dpctl::tensor::kernels::topk_radix_impl;
115103
auto ascending = !largest;
116-
return topk_radix_impl<argTy, IndexTy>(
117-
exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp,
118-
inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset,
119-
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
104+
return topk_radix_impl<argTy, IndexTy>(exec_q, iter_nelems, axis_nelems,
105+
k, ascending, arg_cp, vals_cp,
106+
inds_cp, depends);
120107
}
121108
else {
122109
using dpctl::tensor::kernels::topk_merge_impl;
@@ -126,16 +113,14 @@ sycl::event topk_caller(sycl::queue &exec_q,
126113
argTy>::type;
127114
return topk_merge_impl<argTy, IndexTy, CompTy>(
128115
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
129-
iter_arg_offset, iter_vals_offset, iter_inds_offset,
130-
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
116+
depends);
131117
}
132118
else {
133119
using CompTy = typename dpctl::tensor::py_internal::AscendingSorter<
134120
argTy>::type;
135121
return topk_merge_impl<argTy, IndexTy, CompTy>(
136122
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
137-
iter_arg_offset, iter_vals_offset, iter_inds_offset,
138-
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
123+
depends);
139124
}
140125
}
141126
}
@@ -268,14 +253,11 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
268253
bool is_inds_c_contig = inds.is_c_contiguous();
269254

270255
if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
271-
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
272-
273256
auto fn = topk_dispatch_vector[src_typeid];
274257

275258
sycl::event comp_ev =
276259
fn(exec_q, iter_nelems, axis_nelems, k, largest, src.get_data(),
277-
vals.get_data(), inds.get_data(), zero_offset, zero_offset,
278-
zero_offset, zero_offset, zero_offset, zero_offset, depends);
260+
vals.get_data(), inds.get_data(), depends);
279261

280262
sycl::event keep_args_alive_ev =
281263
dpctl::utils::keep_args_alive(exec_q, {src, vals, inds}, {comp_ev});

0 commit comments

Comments
 (0)