32
32
#include < stdexcept>
33
33
#include < vector>
34
34
35
+ #include < sycl/ext/oneapi/sub_group_mask.hpp>
35
36
#include < sycl/sycl.hpp>
36
37
37
38
#include " kernels/dpctl_tensor_types.hpp"
40
41
#include " kernels/sorting/search_sorted_detail.hpp"
41
42
#include " kernels/sorting/sort_utils.hpp"
42
43
#include " utils/sycl_alloc_utils.hpp"
43
- #include < sycl/ext/oneapi/sub_group_mask.hpp>
44
44
45
45
namespace dpctl
46
46
{
@@ -134,11 +134,9 @@ sycl::event write_out_impl(sycl::queue &exec_q,
134
134
135
135
} // namespace topk_detail
136
136
137
- template <typename T1, typename T2, typename T3>
138
- class topk_populate_index_data_krn ;
137
+ template <typename T1, typename T2> class topk_populate_index_data_krn ;
139
138
140
- template <typename T1, typename T2, typename T3>
141
- class topk_full_merge_map_back_krn ;
139
+ template <typename T1, typename T2> class topk_full_merge_map_back_krn ;
142
140
143
141
template <typename argTy, typename IndexTy, typename CompT>
144
142
sycl::event
@@ -158,7 +156,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
158
156
// extract USM pointer
159
157
IndexTy *index_data = index_data_owner.get ();
160
158
161
- using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy, CompT >;
159
+ using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy>;
162
160
163
161
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
164
162
@@ -179,8 +177,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
179
177
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
180
178
{base_sort_ev});
181
179
182
- using WriteOutKernelName =
183
- topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
180
+ using WriteOutKernelName = topk_full_merge_map_back_krn<argTy, IndexTy>;
184
181
185
182
sycl::event write_out_ev =
186
183
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
@@ -194,8 +191,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
194
191
return cleanup_host_task_event;
195
192
};
196
193
197
- template <typename T1, typename T2, typename T3>
198
- class topk_partial_merge_map_back_krn ;
194
+ template <typename T1, typename T2> class topk_partial_merge_map_back_krn ;
199
195
200
196
template <typename T1, typename T2, typename Comp>
201
197
class topk_over_work_group_krn ;
@@ -213,24 +209,15 @@ sycl::event topk_merge_impl(
213
209
const char *arg_cp,
214
210
char *vals_cp,
215
211
char *inds_cp,
216
- dpctl::tensor::ssize_t iter_arg_offset,
217
- dpctl::tensor::ssize_t iter_vals_offset,
218
- dpctl::tensor::ssize_t iter_inds_offset,
219
- dpctl::tensor::ssize_t axis_arg_offset,
220
- dpctl::tensor::ssize_t axis_vals_offset,
221
- dpctl::tensor::ssize_t axis_inds_offset,
222
212
const std::vector<sycl::event> &depends)
223
213
{
224
214
if (axis_nelems < k) {
225
215
throw std::runtime_error (" Invalid sort axis size for value of k" );
226
216
}
227
217
228
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
229
- iter_arg_offset + axis_arg_offset;
230
- argTy *vals_tp = reinterpret_cast <argTy *>(vals_cp) + iter_vals_offset +
231
- axis_vals_offset;
232
- IndexTy *inds_tp = reinterpret_cast <IndexTy *>(inds_cp) + iter_inds_offset +
233
- axis_inds_offset;
218
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
219
+ argTy *vals_tp = reinterpret_cast <argTy *>(vals_cp);
220
+ IndexTy *inds_tp = reinterpret_cast <IndexTy *>(inds_cp);
234
221
235
222
using dpctl::tensor::kernels::IndexComp;
236
223
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
@@ -434,7 +421,7 @@ sycl::event topk_merge_impl(
434
421
435
422
// Write out top k of the merge-sorted memory
436
423
using WriteOutKernelName =
437
- topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp >;
424
+ topk_partial_merge_map_back_krn<argTy, IndexTy>;
438
425
439
426
sycl::event write_topk_ev =
440
427
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
@@ -462,24 +449,15 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
462
449
const char *arg_cp,
463
450
char *vals_cp,
464
451
char *inds_cp,
465
- dpctl::tensor::ssize_t iter_arg_offset,
466
- dpctl::tensor::ssize_t iter_vals_offset,
467
- dpctl::tensor::ssize_t iter_inds_offset,
468
- dpctl::tensor::ssize_t axis_arg_offset,
469
- dpctl::tensor::ssize_t axis_vals_offset,
470
- dpctl::tensor::ssize_t axis_inds_offset,
471
452
const std::vector<sycl::event> &depends)
472
453
{
473
454
if (axis_nelems < k) {
474
455
throw std::runtime_error (" Invalid sort axis size for value of k" );
475
456
}
476
457
477
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
478
- iter_arg_offset + axis_arg_offset;
479
- argTy *vals_tp = reinterpret_cast <argTy *>(vals_cp) + iter_vals_offset +
480
- axis_vals_offset;
481
- IndexTy *inds_tp = reinterpret_cast <IndexTy *>(inds_cp) + iter_inds_offset +
482
- axis_inds_offset;
458
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
459
+ argTy *vals_tp = reinterpret_cast <argTy *>(vals_cp);
460
+ IndexTy *inds_tp = reinterpret_cast <IndexTy *>(inds_cp);
483
461
484
462
const std::size_t total_nelems = iter_nelems * axis_nelems;
485
463
const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
@@ -494,7 +472,7 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
494
472
using IdentityProjT = radix_sort_details::IdentityProj;
495
473
using IndexedProjT =
496
474
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
497
- const IndexedProjT proj_op{arg_tp, IdentityProjT{} };
475
+ const IndexedProjT proj_op{arg_tp};
498
476
499
477
using IotaKernelName = topk_iota_krn<argTy, IndexTy>;
500
478
0 commit comments