@@ -212,29 +212,85 @@ T custom_reduce_over_group(const GroupT &wg,
212
212
return sycl::group_broadcast (wg, red_val_over_wg, 0 );
213
213
}
214
214
215
- template <typename T, typename GroupT, typename LocAccT, typename OpT>
216
- T custom_inclusive_scan_over_group (const GroupT &wg,
217
- LocAccT local_mem_acc,
218
- const T local_val,
219
- const OpT &op)
215
+ template <typename GroupT,
216
+ typename SubGroupT,
217
+ typename LocAccT,
218
+ typename T,
219
+ typename OpT>
220
+ T custom_inclusive_scan_over_group (GroupT &&wg,
221
+ SubGroupT &&sg,
222
+ LocAccT &&local_mem_acc,
223
+ const T &local_val,
224
+ const T &identity,
225
+ OpT &&op)
220
226
{
221
227
const std::uint32_t local_id = wg.get_local_id (0 );
222
228
const std::uint32_t wgs = wg.get_local_range (0 );
223
- local_mem_acc[local_id] = local_val;
224
229
230
+ const std::uint32_t lane_id = sg.get_local_id ()[0 ];
231
+ const std::uint32_t sgSize = sg.get_local_range ()[0 ];
232
+
233
+ T scan_val = local_val;
234
+ for (std::uint32_t step = 1 ; step < sgSize; step *= 2 ) {
235
+ const bool advanced_lane = (lane_id >= step);
236
+ const std::uint32_t src_lane_id =
237
+ (advanced_lane ? lane_id - step : lane_id);
238
+ const T modifier = sycl::select_from_group (sg, scan_val, src_lane_id);
239
+ if (advanced_lane) {
240
+ scan_val = op (scan_val, modifier);
241
+ }
242
+ }
243
+
244
+ local_mem_acc[local_id] = scan_val;
225
245
sycl::group_barrier (wg, sycl::memory_scope::work_group);
226
246
227
- if (wg.leader ()) {
228
- T scan_val = local_mem_acc[0 ];
229
- for (std::uint32_t i = 1 ; i < wgs; ++i) {
230
- scan_val = op (local_mem_acc[i], scan_val);
231
- local_mem_acc[i] = scan_val;
247
+ const std::uint32_t max_sgSize = sg.get_max_local_range ()[0 ];
248
+ const std::uint32_t sgr_id = sg.get_group_id ()[0 ];
249
+
250
+ // now scan
251
+ const std::uint32_t n_aggregates = 1 + ((wgs - 1 ) / max_sgSize);
252
+ const bool large_wg = (n_aggregates > max_sgSize);
253
+ if (large_wg) {
254
+ if (wg.leader ()) {
255
+ T _scan_val = identity;
256
+ for (std::uint32_t i = 1 ; i <= n_aggregates - max_sgSize; ++i) {
257
+ _scan_val = op (local_mem_acc[i * max_sgSize - 1 ], _scan_val);
258
+ local_mem_acc[i * max_sgSize - 1 ] = _scan_val;
259
+ }
260
+ }
261
+ sycl::group_barrier (wg, sycl::memory_scope::work_group);
262
+ }
263
+
264
+ if (sgr_id == 0 && lane_id < n_aggregates) {
265
+ const std::uint32_t offset =
266
+ (large_wg) ? n_aggregates - max_sgSize : 0u ;
267
+ T __scan_val = (offset + lane_id > 0 )
268
+ ? local_mem_acc[(offset + lane_id) * max_sgSize - 1 ]
269
+ : identity;
270
+ for (std::uint32_t step = 1 ; step < sgSize; step *= 2 ) {
271
+ const bool advanced_lane = (lane_id >= step);
272
+ const std::uint32_t src_lane_id =
273
+ (advanced_lane ? lane_id - step : lane_id);
274
+ const T modifier =
275
+ sycl::select_from_group (sg, __scan_val, src_lane_id);
276
+ if (advanced_lane) {
277
+ __scan_val = op (__scan_val, modifier);
278
+ }
232
279
}
280
+ sycl::group_barrier (sg);
281
+ local_mem_acc[(offset + lane_id) * max_sgSize - 1 ] = __scan_val;
233
282
}
283
+ sycl::group_barrier (wg, sycl::memory_scope::work_group);
234
284
235
- // ensure all work-items see the same SLM that leader updated
285
+ if (sgr_id > 0 ) {
286
+ const T modifier = local_mem_acc[sgr_id * max_sgSize - 1 ];
287
+ scan_val = op (scan_val, modifier);
288
+ }
289
+
290
+ // ensure all work-items finished reading from SLM
236
291
sycl::group_barrier (wg, sycl::memory_scope::work_group);
237
- return local_mem_acc[local_id];
292
+
293
+ return scan_val;
238
294
}
239
295
240
296
// Reduction functors
0 commit comments