
Commit ce43706

Introduce SYCL utilities sub_group_load/sub_group_store
This resolves compiler warnings about the deprecated sub_group::load and sub_group::store methods (warnings in builds with the nightly SYCLOS DPC++ bundle should now be fixed). Additionally, the unsigned int type used for template parameters was replaced with std::uint8_t.
1 parent 70a0a3f commit ce43706


76 files changed: 761 additions, 702 deletions
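The new sub_group_load/sub_group_store utilities live in utils/sycl_utils.hpp, which is not among the files shown here. As a rough, hypothetical illustration of the call shape used in the hunks below — not the actual dpctl implementation — such wrappers could reproduce the blocked access pattern of the deprecated member functions with per-lane strided accesses:

#include <sycl/sycl.hpp>

#include <cstdint>
#include <type_traits>

namespace sketch
{

// Hypothetical stand-in for dpctl::tensor::sycl_utils::sub_group_load.
// A sub-group of size S cooperatively reads vec_sz * S contiguous elements
// starting at `ptr`; lane `lid` receives ptr[lid + k * S] in component k,
// matching the blocked layout of the deprecated sycl::sub_group::load.
template <std::uint8_t vec_sz, typename MultiPtrT>
auto sub_group_load(const sycl::sub_group &sg, MultiPtrT ptr)
{
    using ElemT = std::remove_const_t<typename MultiPtrT::value_type>;
    const std::size_t lid = sg.get_local_id()[0];
    const std::size_t sg_size = sg.get_max_local_range()[0];

    sycl::vec<ElemT, vec_sz> res;
    for (std::uint8_t k = 0; k < vec_sz; ++k) {
        res[k] = *(ptr + (lid + k * sg_size));
    }
    return res;
}

// Hypothetical stand-in for dpctl::tensor::sycl_utils::sub_group_store:
// the exact inverse of the load above.
template <std::uint8_t vec_sz, typename ElemT, typename MultiPtrT>
void sub_group_store(const sycl::sub_group &sg,
                     const sycl::vec<ElemT, vec_sz> &vec,
                     MultiPtrT ptr)
{
    const std::size_t lid = sg.get_local_id()[0];
    const std::size_t sg_size = sg.get_max_local_range()[0];

    for (std::uint8_t k = 0; k < vec_sz; ++k) {
        *(ptr + (lid + k * sg_size)) = vec[k];
    }
}

} // namespace sketch

With this shape, sub_group_load<vec_sz>(sg, src_multi_ptr) returns one sycl::vec per work-item while the sub-group as a whole touches a contiguous block of vec_sz * sg_size elements; a real implementation would presumably forward to a native block load/store where the device supports one.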

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp (34 additions, 33 deletions)
@@ -31,6 +31,7 @@
 #include "dpctl_tensor_types.hpp"
 #include "kernels/alignment.hpp"
 #include "utils/offset_utils.hpp"
+#include "utils/sycl_utils.hpp"
 #include "utils/type_utils.hpp"

 namespace dpctl
@@ -49,13 +50,16 @@ using dpctl::tensor::kernels::alignment_utils::
 using dpctl::tensor::kernels::alignment_utils::is_aligned;
 using dpctl::tensor::kernels::alignment_utils::required_alignment;

+using dpctl::tensor::sycl_utils::sub_group_load;
+using dpctl::tensor::sycl_utils::sub_group_store;
+
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_generic_kernel;

 template <typename srcT,
           typename dstT,
-          unsigned int vec_sz,
-          unsigned int n_vecs>
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs>
 class copy_cast_contig_kernel;

 template <typename srcT, typename dstT, typename IndexerT>
@@ -207,8 +211,8 @@ template <typename fnT, typename D, typename S> struct CopyAndCastGenericFactory
 template <typename srcT,
           typename dstT,
           typename CastFnT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class ContigCopyFunctor
 {
@@ -227,58 +231,55 @@ class ContigCopyFunctor
     {
         CastFnT fn{};

+        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
                       is_complex<dstT>::value)
         {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
-            size_t base = ndit.get_global_linear_id();
-
-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            const size_t gid = ndit.get_global_linear_id();
+
+            // start = (gid / sgSize) * elems_per_sg + (gid % sgSize)
+            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
+            const size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid;
+            const size_t end = std::min(nelems, start + elems_per_sg);
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 dst_p[offset] = fn(src_p[offset]);
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_local_range()[0];
-            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
-                          (ndit.get_group(0) * ndit.get_local_range(0) +
-                           sg.get_group_id()[0] * max_sgSize);
-
-            if (base + n_vecs * vec_sz * sgSize < nelems &&
-                sgSize == max_sgSize)
-            {
-                sycl::vec<srcT, vec_sz> src_vec;
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
+
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<dstT, vec_sz> dst_vec;

 #pragma unroll
                 for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                    const size_t offset = base + it * sgSize;
                     auto src_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &src_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&src_p[offset]);
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[offset]);

-                    src_vec = sg.load<vec_sz>(src_multi_ptr);
+                    const sycl::vec<srcT, vec_sz> src_vec =
+                        sub_group_load<vec_sz>(sg, src_multi_ptr);
 #pragma unroll
                     for (std::uint8_t k = 0; k < vec_sz; k++) {
                         dst_vec[k] = fn(src_vec[k]);
                     }
-                    sg.store<vec_sz>(dst_multi_ptr, dst_vec);
+                    sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
+                const size_t start = base + sg.get_local_id()[0];
+                for (size_t k = start; k < nelems; k += sgSize) {
                     dst_p[k] = fn(src_p[k]);
                 }
             }
@@ -332,8 +333,8 @@ sycl::event copy_and_cast_contig_impl(sycl::queue &q,
     dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_cp);

     size_t lws = 64;
-    constexpr unsigned int vec_sz = 4;
-    constexpr unsigned int n_vecs = 2;
+    constexpr std::uint32_t vec_sz = 4;
+    constexpr std::uint32_t n_vecs = 2;
     const size_t n_groups =
         ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
     const auto gws_range = sycl::range<1>(n_groups * lws);
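Both files in this excerpt now share the same modulo-free partitioning of work among sub-group lanes. The comment in the hunk above gives the intended formula, start = (gid / sgSize) * elems_per_sg + (gid % sgSize); substituting gid % sgSize == gid - (gid / sgSize) * sgSize yields the form actually computed, (gid / sgSize) * (elems_per_sg - sgSize) + gid, which avoids the second division. A standalone host-side check of the identity (the constants are illustrative values, not taken from any device):

#include <cassert>
#include <cstddef>

int main()
{
    constexpr std::size_t sgSize = 32;      // illustrative sub-group size
    constexpr std::size_t elems_per_wi = 8; // n_vecs * vec_sz == 2 * 4
    constexpr std::size_t elems_per_sg = sgSize * elems_per_wi;

    for (std::size_t gid = 0; gid < 100000; ++gid) {
        // formula from the comment, with the modulo
        const std::size_t ref = (gid / sgSize) * elems_per_sg + (gid % sgSize);
        // modulo-free form computed by the kernels
        const std::size_t opt = (gid / sgSize) * (elems_per_sg - sgSize) + gid;
        assert(ref == opt);
    }
    return 0;
}

Each lane then strides through its sub-group's block with step sgSize, so at every step consecutive lanes touch consecutive addresses, keeping the accesses coalesced.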

dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp (55 additions, 57 deletions)
@@ -44,8 +44,8 @@ namespace copy_as_contig

 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz = 4u,
-          std::uint32_t n_vecs = 2u,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -68,25 +68,23 @@ class CopyAsCContigFunctor
     {
         static_assert(vec_sz > 0);
         static_assert(n_vecs > 0);
-        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));

-        constexpr std::uint8_t elems_per_wi =
-            static_cast<std::uint8_t>(vec_sz * n_vecs);
+        constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs;

         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
             const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();

-            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // start = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
             // gid % sgSize == gid - (gid / sgSize) * sgSize
-            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
-            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
-            const std::size_t offset_max =
-                std::min(nelems, base + sgSize * elems_per_wi);
+            const std::size_t elems_per_sg = sgSize * elems_per_wi;
+            const std::size_t start =
+                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
+            const std::size_t end = std::min(nelems, start + elems_per_sg);

-            for (size_t offset = base; offset < offset_max; offset += sgSize) {
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
@@ -132,8 +130,8 @@ class CopyAsCContigFunctor

 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sg_load,
           typename KernelName>
 sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -145,7 +143,6 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 {
     static_assert(vec_sz > 0);
     static_assert(n_vecs > 0);
-    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));

     constexpr std::size_t preferred_lws = 256;

@@ -187,8 +184,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,

 template <typename T,
           typename IndexT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sgload>
 class as_contig_krn;

@@ -210,8 +207,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
     const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);

-    constexpr std::uint32_t vec_sz = 4u;
-    constexpr std::uint32_t n_vecs = 2u;
+    constexpr std::uint8_t vec_sz = 4u;
+    constexpr std::uint8_t n_vecs = 2u;

     using dpctl::tensor::kernels::alignment_utils::
         disabled_sg_loadstore_wrapper_krn;
@@ -256,8 +253,8 @@ template <typename fnT, typename T> struct AsCContigFactory

 template <typename T,
           typename IndexerT,
-          std::uint32_t tile_size,
-          std::uint32_t n_lines>
+          std::uint16_t tile_size,
+          std::uint16_t n_lines>
 class as_contig_batch_of_square_matrices_krn;

 namespace detail
@@ -283,14 +280,14 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
     const T *src_tp = reinterpret_cast<const T *>(src_p);
     T *dst_tp = reinterpret_cast<T *>(dst_p);

-    constexpr std::uint32_t private_tile_size = 4;
-    constexpr std::uint32_t n_lines = 2;
-    constexpr std::uint32_t block_size =
+    constexpr std::uint16_t private_tile_size = 4;
+    constexpr std::uint16_t n_lines = 2;
+    constexpr std::uint16_t block_size =
         n_lines * private_tile_size * private_tile_size;

-    constexpr std::uint32_t lws0 = block_size;
-    constexpr std::uint32_t lws1 = n_lines;
-    constexpr std::uint32_t nelems_per_wi = (block_size / lws1);
+    constexpr std::uint16_t lws0 = block_size;
+    constexpr std::uint16_t lws1 = n_lines;
+    constexpr std::uint16_t nelems_per_wi = (block_size / lws1);

     static_assert(nelems_per_wi * lws1 == block_size);
     static_assert(nelems_per_wi == private_tile_size * private_tile_size);
@@ -377,40 +374,41 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             std::array<T, nelems_per_wi> private_block_01 = {T(0)};
             std::array<T, nelems_per_wi> private_block_10 = {T(0)};

-            // 0 <= lid_lin < lws0 * lws1 == (block_size * block_size /
-            // nelems_per_wi) == (block_size/private_tile_size)**2
-            constexpr std::uint32_t n_private_tiles_per_axis =
+            // 0 <= lid_lin < lws0 * lws1 ==
+            // (block_size * block_size / nelems_per_wi) ==
+            // (block_size/private_tile_size)**2
+            constexpr std::uint16_t n_private_tiles_per_axis =
                 block_size / private_tile_size;
-            const std::uint32_t local_tile_id0 =
+            const std::uint16_t local_tile_id0 =
                 lid_lin / n_private_tiles_per_axis;
-            const std::uint32_t local_tile_id1 =
+            const std::uint16_t local_tile_id1 =
                 lid_lin - local_tile_id0 * n_private_tiles_per_axis;

             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;

-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i1 * private_tile_size + pr_i0;
-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;

                         // read (local_tile_id0, local_tile_id1)
-                        const std::uint32_t local_01_offset =
+                        const std::uint16_t local_01_offset =
                             (t0_offset + t1_offset * block_size) + rel_offset;
                         private_block_01[pr_offset] =
                             local_block[local_01_offset];

                         // read (local_tile_id1, local_tile_id0)
-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         private_block_10[pr_offset] =
                             local_block[local_10_offset];
@@ -422,20 +420,20 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                              sycl::memory_scope::work_group);

             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;
-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i0 * private_tile_size + pr_i1;

-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;

                         // write back permuted private blocks
@@ -444,7 +442,7 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                         local_block[local_01_offset] =
                             private_block_10[pr_offset];

-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         local_block[local_10_offset] =
                             private_block_01[pr_offset];
@@ -461,8 +459,8 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             const std::size_t dst_tile_start1 = src_tile_start1;

             if (local_dim0 == block_size && local_dim1 == block_size) {
-                const std::uint32_t dst_i0 = src_i1;
-                const std::uint32_t dst_i1 = src_i0;
+                const std::uint16_t dst_i0 = src_i1;
+                const std::uint16_t dst_i1 = src_i0;

                 const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                 const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);
@@ -471,11 +469,11 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                     dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
                 const std::size_t pr_step_dst = lws1 * dst_stride;

-                const std::uint32_t _local_offset0 =
+                const std::uint16_t _local_offset0 =
                     dst_i0 * block_size + dst_i1;
-                const std::uint32_t _pr_step_local = lws1 * block_size;
+                const std::uint16_t _pr_step_local = lws1 * block_size;

-                for (std::uint32_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
+                for (std::uint16_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
                     if ((dst_gid1 < n) && ((dst_gid0 + pr_id * lws1) < n)) {
                         dst_tp[dst_offset0 + pr_step_dst * pr_id] =
                             local_block[_local_offset0 +
@@ -485,24 +483,24 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             }
             else {
                 // map local_linear_id into (local_dim0, local_dim1)
-                for (std::uint32_t el_id = lid_lin;
+                for (std::uint16_t el_id = lid_lin;
                      el_id < local_dim0 * local_dim1; el_id += lws0 * lws1)
                 {

                     // 0 <= local_i0 < local_dim0
-                    const std::uint32_t loc_i0 = el_id / local_dim1;
+                    const std::uint16_t loc_i0 = el_id / local_dim1;
                     // 0 <= local_i1 < local_dim1
-                    const std::uint32_t loc_i1 = el_id - loc_i0 * local_dim1;
+                    const std::uint16_t loc_i1 = el_id - loc_i0 * local_dim1;

-                    const std::uint32_t dst_i0 = loc_i0;
-                    const std::uint32_t dst_i1 = loc_i1;
+                    const std::uint16_t dst_i0 = loc_i0;
+                    const std::uint16_t dst_i1 = loc_i1;

                     const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                     const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);

                     const std::size_t dst_offset =
                         dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
-                    const std::uint32_t local_offset =
+                    const std::uint16_t local_offset =
                         loc_i0 * block_size + loc_i1;

                     if ((dst_gid1 < n) && (dst_gid0 < n)) {
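The narrowing from std::uint32_t to std::uint16_t in this kernel is safe because the tile geometry is small and fixed at compile time. A compile-time restatement of the constants from the diff above, with the bound check added as an observation:

#include <cstdint>

constexpr std::uint16_t private_tile_size = 4;
constexpr std::uint16_t n_lines = 2;

// block_size = 2 * 4 * 4 = 32
constexpr std::uint16_t block_size =
    n_lines * private_tile_size * private_tile_size;

constexpr std::uint16_t lws0 = block_size;                 // 32
constexpr std::uint16_t lws1 = n_lines;                    // 2
constexpr std::uint16_t nelems_per_wi = block_size / lws1; // 16

// Invariants asserted by the kernel itself:
static_assert(nelems_per_wi * lws1 == block_size);
static_assert(nelems_per_wi == private_tile_size * private_tile_size);

// Added observation: the block_size x block_size local tile holds 1024
// elements, so any in-bounds offset into it fits comfortably in
// std::uint16_t.
static_assert(block_size * block_size - 1 == 1023);
static_assert(block_size * block_size <= UINT16_MAX + 1);

int main() { return 0; }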
