35
35
36
36
#include " kernels/dpctl_tensor_types.hpp"
37
37
#include " merge_sort.hpp"
38
+ #include " radix_sort.hpp"
38
39
#include " utils/sycl_alloc_utils.hpp"
39
40
#include < sycl/ext/oneapi/sub_group_mask.hpp>
40
41
@@ -70,31 +71,25 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
70
71
} // namespace topk_detail
71
72
72
73
template <typename T1, typename T2, typename T3>
73
- class populate_index_data_full_sort_krn ;
74
+ class topk_populate_index_data_krn ;
74
75
75
76
template <typename T1, typename T2, typename T3>
76
- class topk_map_to_rows_full_sort_krn ;
77
-
78
- template <typename T1, typename T2, typename T3> class populate_index_data_krn ;
79
-
80
- template <typename T1, typename T2, typename T3> class topk_map_to_rows_krn ;
77
+ class topk_full_merge_map_back_krn ;
81
78
82
79
template <typename argTy, typename IndexTy, typename CompT>
83
- sycl::event topk_full_sort_impl (
84
- sycl::queue &exec_q,
85
- std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
86
- // matrix when sorting over rows)
87
- std::size_t sort_nelems, // size of each array to sort (length of rows,
88
- // i.e. number of columns)
89
- std::size_t k,
90
- const argTy *arg_tp,
91
- argTy *vals_tp,
92
- IndexTy *inds_tp,
93
- const CompT &comp,
94
- const std::vector<sycl::event> &depends)
80
+ sycl::event
81
+ topk_full_merge_sort_impl (sycl::queue &exec_q,
82
+ std::size_t iter_nelems, // number of sub-arrays
83
+ std::size_t axis_nelems, // size of each sub-array
84
+ std::size_t k,
85
+ const argTy *arg_tp,
86
+ argTy *vals_tp,
87
+ IndexTy *inds_tp,
88
+ const CompT &comp,
89
+ const std::vector<sycl::event> &depends)
95
90
{
96
91
IndexTy *index_data =
97
- sycl::malloc_device<IndexTy>(iter_nelems * sort_nelems , exec_q);
92
+ sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems , exec_q);
98
93
if (index_data == nullptr ) {
99
94
throw std::runtime_error (" Unable to allocate device_memory" );
100
95
}
@@ -103,10 +98,10 @@ sycl::event topk_full_sort_impl(
103
98
exec_q.submit ([&](sycl::handler &cgh) {
104
99
cgh.depends_on (depends);
105
100
106
- auto const &range = sycl::range<1 >(iter_nelems * sort_nelems );
101
+ auto const &range = sycl::range<1 >(iter_nelems * axis_nelems );
107
102
108
103
using KernelName =
109
- populate_index_data_full_sort_krn <argTy, IndexTy, CompT>;
104
+ topk_populate_index_data_krn <argTy, IndexTy, CompT>;
110
105
111
106
cgh.parallel_for <KernelName>(range, [=](sycl::id<1 > id) {
112
107
std::size_t i = id[0 ];
@@ -118,34 +113,33 @@ sycl::event topk_full_sort_impl(
118
113
// Sort segments of the array
119
114
sycl::event base_sort_ev =
120
115
merge_sort_detail::sort_over_work_group_contig_impl (
121
- exec_q, iter_nelems, sort_nelems , index_data, index_data, comp,
116
+ exec_q, iter_nelems, axis_nelems , index_data, index_data, comp,
122
117
sorted_block_size, // modified in place with size of sorted block
123
118
// size
124
119
{populate_indexed_data_ev});
125
120
126
121
// Merge segments in parallel until all elements are sorted
127
122
sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl (
128
- exec_q, iter_nelems, sort_nelems , index_data, comp, sorted_block_size,
123
+ exec_q, iter_nelems, axis_nelems , index_data, comp, sorted_block_size,
129
124
{base_sort_ev});
130
125
131
126
sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
132
127
cgh.depends_on (merges_ev);
133
128
134
- using KernelName =
135
- topk_map_to_rows_full_sort_krn<argTy, IndexTy, CompT>;
129
+ using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
136
130
137
131
cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
138
132
std::size_t gid = id[0 ];
139
133
140
134
std::size_t iter_gid = gid / k;
141
135
std::size_t axis_gid = gid - (iter_gid * k);
142
136
143
- std::size_t src_idx = iter_gid * sort_nelems + axis_gid;
137
+ std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
144
138
std::size_t dst_idx = iter_gid * k + axis_gid;
145
139
146
140
auto res_ind = index_data[src_idx];
147
141
vals_tp[dst_idx] = arg_tp[res_ind];
148
- inds_tp[dst_idx] = res_ind % sort_nelems ;
142
+ inds_tp[dst_idx] = res_ind % axis_nelems ;
149
143
});
150
144
});
151
145
@@ -162,29 +156,32 @@ sycl::event topk_full_sort_impl(
162
156
return cleanup_host_task_event;
163
157
};
164
158
159
+ template <typename T1, typename T2, typename T3>
160
+ class topk_partial_merge_map_back_krn ;
161
+
165
162
template <typename T1, typename T2, typename Comp>
166
163
class topk_over_work_group_krn ;
167
164
168
165
template <typename argTy,
169
166
typename IndexTy,
170
167
typename ValueComp = std::less<argTy>>
171
- sycl::event
172
- topk_impl ( sycl::queue &exec_q,
173
- std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
174
- // in a matrix when sorting over rows)
175
- std::size_t axis_nelems, // size of each array to sort (length of
176
- // rows, i.e. number of columns)
177
- std::size_t k,
178
- const char *arg_cp,
179
- char *vals_cp,
180
- char *inds_cp,
181
- dpctl::tensor::ssize_t iter_arg_offset,
182
- dpctl::tensor::ssize_t iter_vals_offset,
183
- dpctl::tensor::ssize_t iter_inds_offset,
184
- dpctl::tensor::ssize_t axis_arg_offset,
185
- dpctl::tensor::ssize_t axis_vals_offset,
186
- dpctl::tensor::ssize_t axis_inds_offset,
187
- const std::vector<sycl::event> &depends)
168
+ sycl::event topk_merge_impl (
169
+ sycl::queue &exec_q,
170
+ std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
171
+ // in a matrix when sorting over rows)
172
+ std::size_t axis_nelems, // size of each array to sort (length of
173
+ // rows, i.e. number of columns)
174
+ std::size_t k,
175
+ const char *arg_cp,
176
+ char *vals_cp,
177
+ char *inds_cp,
178
+ dpctl::tensor::ssize_t iter_arg_offset,
179
+ dpctl::tensor::ssize_t iter_vals_offset,
180
+ dpctl::tensor::ssize_t iter_inds_offset,
181
+ dpctl::tensor::ssize_t axis_arg_offset,
182
+ dpctl::tensor::ssize_t axis_vals_offset,
183
+ dpctl::tensor::ssize_t axis_inds_offset,
184
+ const std::vector<sycl::event> &depends)
188
185
{
189
186
if (axis_nelems < k) {
190
187
throw std::runtime_error (" Invalid sort axis size for value of k" );
@@ -201,8 +198,9 @@ topk_impl(sycl::queue &exec_q,
201
198
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
202
199
203
200
if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2 ) {
204
- return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k, arg_tp,
205
- vals_tp, inds_tp, index_comp, depends);
201
+ return topk_full_merge_sort_impl (exec_q, iter_nelems, axis_nelems, k,
202
+ arg_tp, vals_tp, inds_tp, index_comp,
203
+ depends);
206
204
}
207
205
else {
208
206
using PartialKernelName =
@@ -269,9 +267,9 @@ topk_impl(sycl::queue &exec_q,
269
267
if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
270
268
alloc_len >= axis_nelems / 2 )
271
269
{
272
- return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k ,
273
- arg_tp, vals_tp, inds_tp, index_comp ,
274
- depends);
270
+ return topk_full_merge_sort_impl (exec_q, iter_nelems, axis_nelems,
271
+ k, arg_tp, vals_tp, inds_tp,
272
+ index_comp, depends);
275
273
}
276
274
277
275
IndexTy *index_data =
@@ -399,7 +397,8 @@ topk_impl(sycl::queue &exec_q,
399
397
sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
400
398
cgh.depends_on (merges_ev);
401
399
402
- using KernelName = topk_map_to_rows_krn<argTy, IndexTy, ValueComp>;
400
+ using KernelName =
401
+ topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
403
402
404
403
cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
405
404
std::size_t gid = id[0 ];
@@ -430,6 +429,109 @@ topk_impl(sycl::queue &exec_q,
430
429
}
431
430
}
432
431
432
+ template <typename T1, typename T2> class topk_iota_krn ;
433
+
434
+ template <typename T1, typename T2> class topk_radix_map_back_krn ;
435
+
436
+ template <typename argTy, typename IndexTy>
437
+ sycl::event topk_radix_impl (sycl::queue &exec_q,
438
+ std::size_t iter_nelems, // number of sub-arrays
439
+ std::size_t axis_nelems, // size of each sub-array
440
+ std::size_t k,
441
+ bool ascending,
442
+ const char *arg_cp,
443
+ char *vals_cp,
444
+ char *inds_cp,
445
+ dpctl::tensor::ssize_t iter_arg_offset,
446
+ dpctl::tensor::ssize_t iter_vals_offset,
447
+ dpctl::tensor::ssize_t iter_inds_offset,
448
+ dpctl::tensor::ssize_t axis_arg_offset,
449
+ dpctl::tensor::ssize_t axis_vals_offset,
450
+ dpctl::tensor::ssize_t axis_inds_offset,
451
+ const std::vector<sycl::event> &depends)
452
+ {
453
+ if (axis_nelems < k) {
454
+ throw std::runtime_error (" Invalid sort axis size for value of k" );
455
+ }
456
+
457
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
458
+ iter_arg_offset + axis_arg_offset;
459
+ argTy *vals_tp = reinterpret_cast <argTy *>(vals_cp) + iter_vals_offset +
460
+ axis_vals_offset;
461
+ IndexTy *inds_tp = reinterpret_cast <IndexTy *>(inds_cp) + iter_inds_offset +
462
+ axis_inds_offset;
463
+
464
+ const std::size_t total_nelems = iter_nelems * axis_nelems;
465
+ const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
466
+ IndexTy *workspace = sycl::malloc_device<IndexTy>(
467
+ padded_total_nelems + total_nelems, exec_q);
468
+
469
+ IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
470
+
471
+ if (nullptr == workspace || nullptr == tmp_tp) {
472
+ throw std::runtime_error (
473
+ " Not enough device memory for radix sort topk" );
474
+ }
475
+
476
+ using IdentityProjT = radix_sort_details::IdentityProj;
477
+ using IndexedProjT =
478
+ radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
479
+ const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
480
+
481
+ sycl::event iota_ev = exec_q.submit ([&](sycl::handler &cgh) {
482
+ cgh.depends_on (depends);
483
+
484
+ using KernelName = topk_iota_krn<argTy, IndexTy>;
485
+
486
+ cgh.parallel_for <KernelName>(
487
+ sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
488
+ size_t i = id[0 ];
489
+ IndexTy sort_id = static_cast <IndexTy>(i);
490
+ workspace[i] = sort_id;
491
+ });
492
+ });
493
+
494
+ sycl::event radix_sort_ev =
495
+ radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
496
+ exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
497
+ ascending, {iota_ev});
498
+
499
+ // Write out top k of the temporary
500
+ sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
501
+ cgh.depends_on (radix_sort_ev);
502
+
503
+ using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
504
+
505
+ cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
506
+ std::size_t gid = id[0 ];
507
+
508
+ std::size_t iter_gid = gid / k;
509
+ std::size_t axis_gid = gid - (iter_gid * k);
510
+
511
+ std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
512
+ std::size_t dst_idx = iter_gid * k + axis_gid;
513
+
514
+ IndexTy res_ind = tmp_tp[src_idx];
515
+ vals_tp[dst_idx] = arg_tp[res_ind];
516
+ inds_tp[dst_idx] = res_ind % axis_nelems;
517
+ });
518
+ });
519
+
520
+ sycl::event cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
521
+ cgh.depends_on (write_topk_ev);
522
+
523
+ const sycl::context &ctx = exec_q.get_context ();
524
+
525
+ using dpctl::tensor::alloc_utils::sycl_free_noexcept;
526
+ cgh.host_task ([ctx, workspace, tmp_tp] {
527
+ sycl_free_noexcept (workspace, ctx);
528
+ sycl_free_noexcept (tmp_tp, ctx);
529
+ });
530
+ });
531
+
532
+ return cleanup_ev;
533
+ }
534
+
433
535
} // end of namespace kernels
434
536
} // end of namespace tensor
435
537
} // end of namespace dpctl
0 commit comments