@@ -145,7 +145,7 @@ template <typename T> class stack_strided_t
145
145
146
146
namespace su_ns = dpctl::tensor::sycl_utils;
147
147
148
- using nwiT = std::uint16_t ;
148
+ using nwiT = std::uint32_t ;
149
149
150
150
template <typename inputT,
151
151
typename outputT,
@@ -156,7 +156,18 @@ template <typename inputT,
156
156
typename TransformerT,
157
157
typename ScanOpT,
158
158
bool include_initial>
159
- class inclusive_scan_iter_local_scan_krn ;
159
+ class inclusive_scan_iter_local_scan_blocked_krn ;
160
+
161
+ template <typename inputT,
162
+ typename outputT,
163
+ nwiT n_wi,
164
+ typename IterIndexerT,
165
+ typename InpIndexerT,
166
+ typename OutIndexerT,
167
+ typename TransformerT,
168
+ typename ScanOpT,
169
+ bool include_initial>
170
+ class inclusive_scan_iter_local_scan_striped_krn ;
160
171
161
172
template <typename inputT,
162
173
typename outputT,
@@ -177,22 +188,22 @@ template <typename inputT,
177
188
typename ScanOpT,
178
189
bool include_initial = false >
179
190
sycl::event
180
- inclusive_scan_base_step (sycl::queue &exec_q,
181
- const std::size_t wg_size,
182
- const std::size_t iter_nelems,
183
- const std::size_t acc_nelems,
184
- const inputT *input,
185
- outputT *output,
186
- const std::size_t s0,
187
- const std::size_t s1,
188
- const IterIndexerT &iter_indexer,
189
- const InpIndexerT &inp_indexer,
190
- const OutIndexerT &out_indexer,
191
- TransformerT transformer,
192
- const ScanOpT &scan_op,
193
- outputT identity,
194
- std::size_t &acc_groups,
195
- const std::vector<sycl::event> &depends = {})
191
+ inclusive_scan_base_step_blocked (sycl::queue &exec_q,
192
+ const std::uint32_t wg_size,
193
+ const std::size_t iter_nelems,
194
+ const std::size_t acc_nelems,
195
+ const inputT *input,
196
+ outputT *output,
197
+ const std::size_t s0,
198
+ const std::size_t s1,
199
+ const IterIndexerT &iter_indexer,
200
+ const InpIndexerT &inp_indexer,
201
+ const OutIndexerT &out_indexer,
202
+ TransformerT transformer,
203
+ const ScanOpT &scan_op,
204
+ outputT identity,
205
+ std::size_t &acc_groups,
206
+ const std::vector<sycl::event> &depends = {})
196
207
{
197
208
acc_groups = ceiling_quotient<std::size_t >(acc_nelems, n_wi * wg_size);
198
209
@@ -208,7 +219,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
208
219
209
220
slmT slm_iscan_tmp (lws, cgh);
210
221
211
- using KernelName = inclusive_scan_iter_local_scan_krn <
222
+ using KernelName = inclusive_scan_iter_local_scan_blocked_krn <
212
223
inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
213
224
TransformerT, ScanOpT, include_initial>;
214
225
@@ -218,6 +229,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
218
229
const std::size_t gid = it.get_global_id (0 );
219
230
const std::size_t lid = it.get_local_id (0 );
220
231
232
+ const std::uint32_t wg_size = it.get_local_range (0 );
221
233
const std::size_t reduce_chunks = acc_groups * wg_size;
222
234
const std::size_t iter_gid = gid / reduce_chunks;
223
235
const std::size_t chunk_gid = gid - (iter_gid * reduce_chunks);
@@ -296,6 +308,248 @@ inclusive_scan_base_step(sycl::queue &exec_q,
296
308
return inc_scan_phase1_ev;
297
309
}
298
310
311
+ template <typename inputT,
312
+ typename outputT,
313
+ nwiT n_wi,
314
+ typename IterIndexerT,
315
+ typename InpIndexerT,
316
+ typename OutIndexerT,
317
+ typename TransformerT,
318
+ typename ScanOpT,
319
+ bool include_initial = false >
320
+ sycl::event
321
+ inclusive_scan_base_step_striped (sycl::queue &exec_q,
322
+ const std::uint32_t wg_size,
323
+ const std::size_t iter_nelems,
324
+ const std::size_t acc_nelems,
325
+ const inputT *input,
326
+ outputT *output,
327
+ const std::size_t s0,
328
+ const std::size_t s1,
329
+ const IterIndexerT &iter_indexer,
330
+ const InpIndexerT &inp_indexer,
331
+ const OutIndexerT &out_indexer,
332
+ TransformerT transformer,
333
+ const ScanOpT &scan_op,
334
+ outputT identity,
335
+ std::size_t &acc_groups,
336
+ const std::vector<sycl::event> &depends = {})
337
+ {
338
+ const std::uint32_t reduce_nelems_per_wg = n_wi * wg_size;
339
+ acc_groups =
340
+ ceiling_quotient<std::size_t >(acc_nelems, reduce_nelems_per_wg);
341
+
342
+ sycl::event inc_scan_phase1_ev = exec_q.submit ([&](sycl::handler &cgh) {
343
+ cgh.depends_on (depends);
344
+
345
+ using slmT = sycl::local_accessor<outputT, 1 >;
346
+
347
+ const auto &gRange = sycl::range<1 >{iter_nelems * acc_groups * wg_size};
348
+ const auto &lRange = sycl::range<1 >{wg_size};
349
+
350
+ const auto &ndRange = sycl::nd_range<1 >{gRange , lRange};
351
+
352
+ slmT slm_iscan_tmp (reduce_nelems_per_wg, cgh);
353
+
354
+ using KernelName = inclusive_scan_iter_local_scan_striped_krn<
355
+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
356
+ TransformerT, ScanOpT, include_initial>;
357
+
358
+ cgh.parallel_for <KernelName>(ndRange, [=, slm_iscan_tmp =
359
+ std::move (slm_iscan_tmp)](
360
+ sycl::nd_item<1 > it) {
361
+ const std::uint32_t lid = it.get_local_linear_id ();
362
+ const std::uint32_t wg_size = it.get_local_range (0 );
363
+
364
+ const auto &sg = it.get_sub_group ();
365
+ const std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
366
+ const std::size_t sgroup_id = sg.get_group_id ()[0 ];
367
+ const std::uint32_t lane_id = sg.get_local_id ()[0 ];
368
+
369
+ const std::size_t flat_group_id = it.get_group (0 );
370
+ const std::size_t iter_gid = flat_group_id / acc_groups;
371
+ const std::size_t acc_group_id =
372
+ flat_group_id - (iter_gid * acc_groups);
373
+
374
+ const auto &iter_offsets = iter_indexer (iter_gid);
375
+ const auto &inp_iter_offset = iter_offsets.get_first_offset ();
376
+ const auto &out_iter_offset = iter_offsets.get_second_offset ();
377
+
378
+ std::array<outputT, n_wi> local_iscan{};
379
+
380
+ const std::size_t inp_id0 = acc_group_id * n_wi * wg_size +
381
+ sgroup_id * n_wi * sgSize + lane_id;
382
+
383
+ #pragma unroll
384
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
385
+ const std::size_t inp_id = inp_id0 + m_wi * sgSize;
386
+ if constexpr (!include_initial) {
387
+ local_iscan[m_wi] =
388
+ (inp_id < acc_nelems)
389
+ ? transformer (input[inp_iter_offset +
390
+ inp_indexer (s0 + s1 * inp_id)])
391
+ : identity;
392
+ }
393
+ else {
394
+ // shift input to the left by a single element relative to
395
+ // output
396
+ local_iscan[m_wi] =
397
+ (inp_id < acc_nelems && inp_id > 0 )
398
+ ? transformer (
399
+ input[inp_iter_offset +
400
+ inp_indexer ((s0 + s1 * inp_id) - 1 )])
401
+ : identity;
402
+ }
403
+ }
404
+
405
+ // change layout from striped to blocked
406
+ {
407
+ {
408
+ const std::uint32_t local_offset0 = lid * n_wi;
409
+ #pragma unroll
410
+ for (std::uint32_t i = 0 ; i < n_wi; ++i) {
411
+ slm_iscan_tmp[local_offset0 + i] = local_iscan[i];
412
+ }
413
+
414
+ it.barrier (sycl::access::fence_space::local_space);
415
+ }
416
+
417
+ {
418
+ const std::uint32_t block_offset =
419
+ sgroup_id * sgSize * n_wi;
420
+ const std::uint32_t disp0 = lane_id * n_wi;
421
+ #pragma unroll
422
+ for (nwiT i = 0 ; i < n_wi; ++i) {
423
+ const std::uint32_t disp = disp0 + i;
424
+
425
+ // disp == lane_id1 + i1 * sgSize;
426
+ const std::uint32_t i1 = disp / sgSize;
427
+ const std::uint32_t lane_id1 = disp - i1 * sgSize;
428
+
429
+ const std::uint32_t disp_exchanged =
430
+ (lane_id1 * n_wi + i1);
431
+
432
+ local_iscan[i] =
433
+ slm_iscan_tmp[block_offset + disp_exchanged];
434
+ }
435
+
436
+ it.barrier (sycl::access::fence_space::local_space);
437
+ }
438
+ }
439
+
440
+ #pragma unroll
441
+ for (nwiT m_wi = 1 ; m_wi < n_wi; ++m_wi) {
442
+ local_iscan[m_wi] =
443
+ scan_op (local_iscan[m_wi], local_iscan[m_wi - 1 ]);
444
+ }
445
+ // local_iscan is now result of
446
+ // inclusive scan of locally stored inputs
447
+
448
+ outputT wg_iscan_val;
449
+ if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
450
+ outputT>::value)
451
+ {
452
+ wg_iscan_val = sycl::inclusive_scan_over_group (
453
+ it.get_group (), local_iscan.back (), scan_op, identity);
454
+ }
455
+ else {
456
+ wg_iscan_val = su_ns::custom_inclusive_scan_over_group (
457
+ it.get_group (), slm_iscan_tmp, local_iscan.back (), scan_op);
458
+ // ensure all finished reading from SLM, to avoid race condition
459
+ // with subsequent writes into SLM
460
+ it.barrier (sycl::access::fence_space::local_space);
461
+ }
462
+
463
+ slm_iscan_tmp[(lid + 1 ) % wg_size] = wg_iscan_val;
464
+ it.barrier (sycl::access::fence_space::local_space);
465
+ const outputT modifier = (lid == 0 ) ? identity : slm_iscan_tmp[lid];
466
+
467
+ #pragma unroll
468
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
469
+ local_iscan[m_wi] = scan_op (local_iscan[m_wi], modifier);
470
+ }
471
+
472
+ it.barrier (sycl::access::fence_space::local_space);
473
+
474
+ // convert back to blocked layout
475
+ {
476
+ {
477
+ const std::uint32_t local_offset0 = lid * n_wi;
478
+ #pragma unroll
479
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
480
+ slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
481
+ }
482
+
483
+ it.barrier (sycl::access::fence_space::local_space);
484
+ }
485
+ }
486
+
487
+ {
488
+ const std::uint32_t block_offset =
489
+ sgroup_id * sgSize * n_wi + lane_id;
490
+ #pragma unroll
491
+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
492
+ const std::uint32_t m_wi_scaled = m_wi * sgSize;
493
+ const std::size_t out_id = inp_id0 + m_wi_scaled;
494
+ if (out_id < acc_nelems) {
495
+ output[out_iter_offset + out_indexer (out_id)] =
496
+ slm_iscan_tmp[block_offset + m_wi_scaled];
497
+ }
498
+ }
499
+ }
500
+ });
501
+ });
502
+
503
+ return inc_scan_phase1_ev;
504
+ }
505
+
506
+ template <typename inputT,
507
+ typename outputT,
508
+ nwiT n_wi,
509
+ typename IterIndexerT,
510
+ typename InpIndexerT,
511
+ typename OutIndexerT,
512
+ typename TransformerT,
513
+ typename ScanOpT,
514
+ bool include_initial = false >
515
+ sycl::event
516
+ inclusive_scan_base_step (sycl::queue &exec_q,
517
+ const std::uint32_t wg_size,
518
+ const std::size_t iter_nelems,
519
+ const std::size_t acc_nelems,
520
+ const inputT *input,
521
+ outputT *output,
522
+ const std::size_t s0,
523
+ const std::size_t s1,
524
+ const IterIndexerT &iter_indexer,
525
+ const InpIndexerT &inp_indexer,
526
+ const OutIndexerT &out_indexer,
527
+ TransformerT transformer,
528
+ const ScanOpT &scan_op,
529
+ outputT identity,
530
+ std::size_t &acc_groups,
531
+ const std::vector<sycl::event> &depends = {})
532
+ {
533
+ // For small stride use striped load/store.
534
+ // Threshold value chosen experimentally.
535
+ if (s1 <= 16 ) {
536
+ return inclusive_scan_base_step_striped<
537
+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
538
+ TransformerT, ScanOpT, include_initial>(
539
+ exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
540
+ iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
541
+ identity, acc_groups, depends);
542
+ }
543
+ else {
544
+ return inclusive_scan_base_step_blocked<
545
+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
546
+ TransformerT, ScanOpT, include_initial>(
547
+ exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
548
+ iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
549
+ identity, acc_groups, depends);
550
+ }
551
+ }
552
+
299
553
template <typename inputT,
300
554
typename outputT,
301
555
nwiT n_wi,
0 commit comments