@@ -105,14 +105,35 @@ sycl::event map_back_impl(sycl::queue &exec_q,
105
105
std::size_t row_size,
106
106
const std::vector<sycl::event> &dependent_events)
107
107
{
108
+ constexpr std::uint32_t lws = 64 ;
109
+ constexpr std::uint32_t n_wi = 4 ;
110
+ const std::size_t n_groups = (nelems + lws * n_wi - 1 ) / (n_wi * lws);
111
+
112
+ sycl::range<1 > lRange{lws};
113
+ sycl::range<1 > gRange {n_groups * lws};
114
+ sycl::nd_range<1 > ndRange{gRange , lRange};
115
+
108
116
sycl::event map_back_ev = exec_q.submit ([&](sycl::handler &cgh) {
109
117
cgh.depends_on (dependent_events);
110
118
111
- cgh.parallel_for <KernelName>(
112
- sycl::range<1 >(nelems), [=](sycl::id<1 > id) {
113
- const IndexTy linear_index = flat_index_data[id];
114
- reduced_index_data[id] = (linear_index % row_size);
115
- });
119
+ cgh.parallel_for <KernelName>(ndRange, [=](sycl::nd_item<1 > it) {
120
+ const std::size_t gid = it.get_global_linear_id ();
121
+ const auto &sg = it.get_sub_group ();
122
+ const std::uint32_t lane_id = sg.get_local_id ()[0 ];
123
+ const std::uint32_t sg_size = sg.get_max_local_range ()[0 ];
124
+
125
+ const std::size_t start_id = (gid - lane_id) * n_wi + lane_id;
126
+
127
+ #pragma unroll
128
+ for (std::uint32_t i = 0 ; i < n_wi; ++i) {
129
+ const std::size_t data_id = start_id + i * sg_size;
130
+
131
+ if (data_id < nelems) {
132
+ const IndexTy linear_index = flat_index_data[data_id];
133
+ reduced_index_data[data_id] = (linear_index % row_size);
134
+ }
135
+ }
136
+ });
116
137
});
117
138
118
139
return map_back_ev;
0 commit comments