
Commit 18c5ed7

Add implementation of base scan using striped load/store pattern
1 parent 25a961f commit 18c5ed7
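
Background on the commit title, for readers of this diff: the pre-existing kernel (renamed to the "blocked" variant below) appears to give each work-item a contiguous chunk of n_wi elements, while the new "striped" kernel has the lanes of a sub-group read elements interleaved by the sub-group size (inp_id0 + m_wi * sgSize in the hunk below), so consecutive lanes touch consecutive addresses on each iteration. A minimal host-side sketch of the two index patterns, assuming illustrative names and sizes (blocked_index, striped_index, n_wi = 4, sgSize = 8 are not part of the commit):

// Illustrative only: which global element a lane loads on iteration m_wi
// under the two layouts; the real kernels compute these indices inline.
#include <cstddef>
#include <cstdint>
#include <iostream>

// Blocked (assumed behaviour of the pre-existing kernel): each work-item
// owns n_wi consecutive elements.
std::size_t blocked_index(std::size_t chunk_gid, std::uint32_t n_wi,
                          std::uint32_t m_wi)
{
    return chunk_gid * n_wi + m_wi;
}

// Striped (new kernel): lanes of a sub-group read interleaved elements with
// stride sgSize, mirroring inp_id0 + m_wi * sgSize from the diff
// (sub-group 0 of accumulation group 0 shown here).
std::size_t striped_index(std::uint32_t lane_id, std::uint32_t sgSize,
                          std::uint32_t m_wi)
{
    return lane_id + static_cast<std::size_t>(m_wi) * sgSize;
}

int main()
{
    constexpr std::uint32_t n_wi = 4, sgSize = 8;
    for (std::uint32_t lane = 0; lane < sgSize; ++lane) {
        std::cout << "lane " << lane << ": blocked loads "
                  << blocked_index(lane, n_wi, 0) << ", striped loads "
                  << striped_index(lane, sgSize, 0) << '\n';
    }
    return 0;
}

At m_wi = 0 the striped pattern gives lanes 0..7 the indices 0..7 (one contiguous, coalescible segment per sub-group), whereas the blocked pattern gives 0, 4, 8, ... with a gap of n_wi elements between neighbouring lanes.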

File tree

1 file changed: +273 −19 lines changed


dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 273 additions & 19 deletions
@@ -145,7 +145,7 @@ template <typename T> class stack_strided_t
 
 namespace su_ns = dpctl::tensor::sycl_utils;
 
-using nwiT = std::uint16_t;
+using nwiT = std::uint32_t;
 
 template <typename inputT,
           typename outputT,
@@ -156,7 +156,18 @@ template <typename inputT,
           typename TransformerT,
           typename ScanOpT,
           bool include_initial>
-class inclusive_scan_iter_local_scan_krn;
+class inclusive_scan_iter_local_scan_blocked_krn;
+
+template <typename inputT,
+          typename outputT,
+          nwiT n_wi,
+          typename IterIndexerT,
+          typename InpIndexerT,
+          typename OutIndexerT,
+          typename TransformerT,
+          typename ScanOpT,
+          bool include_initial>
+class inclusive_scan_iter_local_scan_striped_krn;
 
 template <typename inputT,
           typename outputT,
@@ -177,22 +188,22 @@ template <typename inputT,
           typename ScanOpT,
           bool include_initial = false>
 sycl::event
-inclusive_scan_base_step(sycl::queue &exec_q,
-                         const std::size_t wg_size,
-                         const std::size_t iter_nelems,
-                         const std::size_t acc_nelems,
-                         const inputT *input,
-                         outputT *output,
-                         const std::size_t s0,
-                         const std::size_t s1,
-                         const IterIndexerT &iter_indexer,
-                         const InpIndexerT &inp_indexer,
-                         const OutIndexerT &out_indexer,
-                         TransformerT transformer,
-                         const ScanOpT &scan_op,
-                         outputT identity,
-                         std::size_t &acc_groups,
-                         const std::vector<sycl::event> &depends = {})
+inclusive_scan_base_step_blocked(sycl::queue &exec_q,
+                                 const std::uint32_t wg_size,
+                                 const std::size_t iter_nelems,
+                                 const std::size_t acc_nelems,
+                                 const inputT *input,
+                                 outputT *output,
+                                 const std::size_t s0,
+                                 const std::size_t s1,
+                                 const IterIndexerT &iter_indexer,
+                                 const InpIndexerT &inp_indexer,
+                                 const OutIndexerT &out_indexer,
+                                 TransformerT transformer,
+                                 const ScanOpT &scan_op,
+                                 outputT identity,
+                                 std::size_t &acc_groups,
+                                 const std::vector<sycl::event> &depends = {})
 {
     acc_groups = ceiling_quotient<std::size_t>(acc_nelems, n_wi * wg_size);
 
@@ -208,7 +219,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
 
         slmT slm_iscan_tmp(lws, cgh);
 
-        using KernelName = inclusive_scan_iter_local_scan_krn<
+        using KernelName = inclusive_scan_iter_local_scan_blocked_krn<
             inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
             TransformerT, ScanOpT, include_initial>;
 
@@ -218,6 +229,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
             const std::size_t gid = it.get_global_id(0);
             const std::size_t lid = it.get_local_id(0);
 
+            const std::uint32_t wg_size = it.get_local_range(0);
             const std::size_t reduce_chunks = acc_groups * wg_size;
             const std::size_t iter_gid = gid / reduce_chunks;
             const std::size_t chunk_gid = gid - (iter_gid * reduce_chunks);
@@ -296,6 +308,248 @@ inclusive_scan_base_step(sycl::queue &exec_q,
     return inc_scan_phase1_ev;
 }
 
+template <typename inputT,
+          typename outputT,
+          nwiT n_wi,
+          typename IterIndexerT,
+          typename InpIndexerT,
+          typename OutIndexerT,
+          typename TransformerT,
+          typename ScanOpT,
+          bool include_initial = false>
+sycl::event
+inclusive_scan_base_step_striped(sycl::queue &exec_q,
+                                 const std::uint32_t wg_size,
+                                 const std::size_t iter_nelems,
+                                 const std::size_t acc_nelems,
+                                 const inputT *input,
+                                 outputT *output,
+                                 const std::size_t s0,
+                                 const std::size_t s1,
+                                 const IterIndexerT &iter_indexer,
+                                 const InpIndexerT &inp_indexer,
+                                 const OutIndexerT &out_indexer,
+                                 TransformerT transformer,
+                                 const ScanOpT &scan_op,
+                                 outputT identity,
+                                 std::size_t &acc_groups,
+                                 const std::vector<sycl::event> &depends = {})
+{
+    const std::uint32_t reduce_nelems_per_wg = n_wi * wg_size;
+    acc_groups =
+        ceiling_quotient<std::size_t>(acc_nelems, reduce_nelems_per_wg);
+
+    sycl::event inc_scan_phase1_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using slmT = sycl::local_accessor<outputT, 1>;
+
+        const auto &gRange = sycl::range<1>{iter_nelems * acc_groups * wg_size};
+        const auto &lRange = sycl::range<1>{wg_size};
+
+        const auto &ndRange = sycl::nd_range<1>{gRange, lRange};
+
+        slmT slm_iscan_tmp(reduce_nelems_per_wg, cgh);
+
+        using KernelName = inclusive_scan_iter_local_scan_striped_krn<
+            inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
+            TransformerT, ScanOpT, include_initial>;
+
+        cgh.parallel_for<KernelName>(ndRange, [=, slm_iscan_tmp =
+                                                      std::move(slm_iscan_tmp)](
+                                                  sycl::nd_item<1> it) {
+            const std::uint32_t lid = it.get_local_linear_id();
+            const std::uint32_t wg_size = it.get_local_range(0);
+
+            const auto &sg = it.get_sub_group();
+            const std::uint32_t sgSize = sg.get_max_local_range()[0];
+            const std::size_t sgroup_id = sg.get_group_id()[0];
+            const std::uint32_t lane_id = sg.get_local_id()[0];
+
+            const std::size_t flat_group_id = it.get_group(0);
+            const std::size_t iter_gid = flat_group_id / acc_groups;
+            const std::size_t acc_group_id =
+                flat_group_id - (iter_gid * acc_groups);
+
+            const auto &iter_offsets = iter_indexer(iter_gid);
+            const auto &inp_iter_offset = iter_offsets.get_first_offset();
+            const auto &out_iter_offset = iter_offsets.get_second_offset();
+
+            std::array<outputT, n_wi> local_iscan{};
+
+            const std::size_t inp_id0 = acc_group_id * n_wi * wg_size +
+                                        sgroup_id * n_wi * sgSize + lane_id;
+
+#pragma unroll
+            for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                const std::size_t inp_id = inp_id0 + m_wi * sgSize;
+                if constexpr (!include_initial) {
+                    local_iscan[m_wi] =
+                        (inp_id < acc_nelems)
+                            ? transformer(input[inp_iter_offset +
+                                                inp_indexer(s0 + s1 * inp_id)])
+                            : identity;
+                }
+                else {
+                    // shift input to the left by a single element relative to
+                    // output
+                    local_iscan[m_wi] =
+                        (inp_id < acc_nelems && inp_id > 0)
+                            ? transformer(
+                                  input[inp_iter_offset +
+                                        inp_indexer((s0 + s1 * inp_id) - 1)])
+                            : identity;
+                }
+            }
+
+            // change layout from striped to blocked
+            {
+                {
+                    const std::uint32_t local_offset0 = lid * n_wi;
+#pragma unroll
+                    for (std::uint32_t i = 0; i < n_wi; ++i) {
+                        slm_iscan_tmp[local_offset0 + i] = local_iscan[i];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+
+                {
+                    const std::uint32_t block_offset =
+                        sgroup_id * sgSize * n_wi;
+                    const std::uint32_t disp0 = lane_id * n_wi;
+#pragma unroll
+                    for (nwiT i = 0; i < n_wi; ++i) {
+                        const std::uint32_t disp = disp0 + i;
+
+                        // disp == lane_id1 + i1 * sgSize;
+                        const std::uint32_t i1 = disp / sgSize;
+                        const std::uint32_t lane_id1 = disp - i1 * sgSize;
+
+                        const std::uint32_t disp_exchanged =
+                            (lane_id1 * n_wi + i1);
+
+                        local_iscan[i] =
+                            slm_iscan_tmp[block_offset + disp_exchanged];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+            }
+
+#pragma unroll
+            for (nwiT m_wi = 1; m_wi < n_wi; ++m_wi) {
+                local_iscan[m_wi] =
+                    scan_op(local_iscan[m_wi], local_iscan[m_wi - 1]);
+            }
+            // local_iscan is now result of
+            // inclusive scan of locally stored inputs
+
+            outputT wg_iscan_val;
+            if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
+                                                            outputT>::value)
+            {
+                wg_iscan_val = sycl::inclusive_scan_over_group(
+                    it.get_group(), local_iscan.back(), scan_op, identity);
+            }
+            else {
+                wg_iscan_val = su_ns::custom_inclusive_scan_over_group(
+                    it.get_group(), slm_iscan_tmp, local_iscan.back(), scan_op);
+                // ensure all finished reading from SLM, to avoid race condition
+                // with subsequent writes into SLM
+                it.barrier(sycl::access::fence_space::local_space);
+            }
+
+            slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val;
+            it.barrier(sycl::access::fence_space::local_space);
+            const outputT modifier = (lid == 0) ? identity : slm_iscan_tmp[lid];
+
+#pragma unroll
+            for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                local_iscan[m_wi] = scan_op(local_iscan[m_wi], modifier);
+            }
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            // convert back to blocked layout
+            {
+                {
+                    const std::uint32_t local_offset0 = lid * n_wi;
+#pragma unroll
+                    for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                        slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
+                    }
+
+                    it.barrier(sycl::access::fence_space::local_space);
+                }
+            }
+
+            {
+                const std::uint32_t block_offset =
+                    sgroup_id * sgSize * n_wi + lane_id;
+#pragma unroll
+                for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
+                    const std::uint32_t m_wi_scaled = m_wi * sgSize;
+                    const std::size_t out_id = inp_id0 + m_wi_scaled;
+                    if (out_id < acc_nelems) {
+                        output[out_iter_offset + out_indexer(out_id)] =
+                            slm_iscan_tmp[block_offset + m_wi_scaled];
+                    }
+                }
+            }
+        });
+    });
+
+    return inc_scan_phase1_ev;
+}
+
+template <typename inputT,
+          typename outputT,
+          nwiT n_wi,
+          typename IterIndexerT,
+          typename InpIndexerT,
+          typename OutIndexerT,
+          typename TransformerT,
+          typename ScanOpT,
+          bool include_initial = false>
+sycl::event
+inclusive_scan_base_step(sycl::queue &exec_q,
+                         const std::uint32_t wg_size,
+                         const std::size_t iter_nelems,
+                         const std::size_t acc_nelems,
+                         const inputT *input,
+                         outputT *output,
+                         const std::size_t s0,
+                         const std::size_t s1,
+                         const IterIndexerT &iter_indexer,
+                         const InpIndexerT &inp_indexer,
+                         const OutIndexerT &out_indexer,
+                         TransformerT transformer,
+                         const ScanOpT &scan_op,
+                         outputT identity,
+                         std::size_t &acc_groups,
+                         const std::vector<sycl::event> &depends = {})
+{
+    // For small stride use striped load/store.
+    // Threshold value chosen experimentally.
+    if (s1 <= 16) {
+        return inclusive_scan_base_step_striped<
+            inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
+            TransformerT, ScanOpT, include_initial>(
+            exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
+            iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
+            identity, acc_groups, depends);
+    }
+    else {
+        return inclusive_scan_base_step_blocked<
+            inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
+            TransformerT, ScanOpT, include_initial>(
+            exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
+            iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
+            identity, acc_groups, depends);
+    }
+}
+
 template <typename inputT,
           typename outputT,
           nwiT n_wi,
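
A note on the striped-to-blocked exchange added in the hunk above: the disp / i1 / lane_id1 arithmetic permutes the sgSize-by-n_wi tile a sub-group stages in SLM so that, after the exchange, each lane holds n_wi consecutive elements of its sub-group's tile, which keeps the per-work-item serial scan correct. A standalone host-side sketch, restricted to one sub-group with illustrative sizes, that mirrors only this index arithmetic and checks the resulting layout:

// Standalone sketch (one sub-group, illustrative sizes): verifies that the
// disp / i1 / lane_id1 exchange turns a striped register layout into a
// blocked one. It mirrors only the index arithmetic of the new kernel.
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    constexpr std::uint32_t sgSize = 8, n_wi = 4;

    // Phase 1: lane `lane_id` writes register i (holding striped tile
    // element lane_id + i * sgSize) to SLM slot lane_id * n_wi + i.
    std::vector<std::uint32_t> slm(sgSize * n_wi);
    for (std::uint32_t lane_id = 0; lane_id < sgSize; ++lane_id) {
        for (std::uint32_t i = 0; i < n_wi; ++i) {
            slm[lane_id * n_wi + i] = lane_id + i * sgSize;
        }
    }

    // Phase 2: each lane reads back through the exchanged displacement.
    for (std::uint32_t lane_id = 0; lane_id < sgSize; ++lane_id) {
        const std::uint32_t disp0 = lane_id * n_wi;
        for (std::uint32_t i = 0; i < n_wi; ++i) {
            const std::uint32_t disp = disp0 + i;
            const std::uint32_t i1 = disp / sgSize;
            const std::uint32_t lane_id1 = disp - i1 * sgSize;
            const std::uint32_t disp_exchanged = lane_id1 * n_wi + i1;

            // Blocked layout: lane `lane_id` now holds tile elements
            // lane_id * n_wi, ..., lane_id * n_wi + n_wi - 1 in order.
            assert(slm[disp_exchanged] == lane_id * n_wi + i);
        }
    }
    return 0;
}

The new inclusive_scan_base_step wrapper then selects the striped path only for small strides (s1 <= 16, a threshold the diff describes as chosen experimentally), where the interleaved loads of neighbouring lanes stay close enough in memory to coalesce; larger strides fall back to the blocked kernel.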
