Skip to content

Commit 399cdd1

Browse files
Replace map_back_impl in sort_utils
Change the kernel so that each work-item processes several data elements.
1 parent addb341 commit 399cdd1

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,35 @@ sycl::event map_back_impl(sycl::queue &exec_q,
105105
std::size_t row_size,
106106
const std::vector<sycl::event> &dependent_events)
107107
{
108+
constexpr std::uint32_t lws = 64;
109+
constexpr std::uint32_t n_wi = 4;
110+
const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws);
111+
112+
sycl::range<1> lRange{lws};
113+
sycl::range<1> gRange{n_groups * lws};
114+
sycl::nd_range<1> ndRange{gRange, lRange};
115+
108116
sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
109117
cgh.depends_on(dependent_events);
110118

111-
cgh.parallel_for<KernelName>(
112-
sycl::range<1>(nelems), [=](sycl::id<1> id) {
113-
const IndexTy linear_index = flat_index_data[id];
114-
reduced_index_data[id] = (linear_index % row_size);
115-
});
119+
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> it) {
120+
const std::size_t gid = it.get_global_linear_id();
121+
const auto &sg = it.get_sub_group();
122+
const std::uint32_t lane_id = sg.get_local_id()[0];
123+
const std::uint32_t sg_size = sg.get_max_local_range()[0];
124+
125+
const std::size_t start_id = (gid - lane_id) * n_wi + lane_id;
126+
127+
#pragma unroll
128+
for (std::uint32_t i = 0; i < n_wi; ++i) {
129+
const std::size_t data_id = start_id + i * sg_size;
130+
131+
if (data_id < nelems) {
132+
const IndexTy linear_index = flat_index_data[data_id];
133+
reduced_index_data[data_id] = (linear_index % row_size);
134+
}
135+
}
136+
});
116137
});
117138

118139
return map_back_ev;

0 commit comments

Comments
 (0)