@@ -44,8 +44,8 @@ namespace copy_as_contig
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz = 4u,
-          std::uint32_t n_vecs = 2u,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -68,25 +68,23 @@ class CopyAsCContigFunctor
     {
         static_assert(vec_sz > 0);
         static_assert(n_vecs > 0);
-        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
 
-        constexpr std::uint8_t elems_per_wi =
-            static_cast<std::uint8_t>(vec_sz * n_vecs);
+        constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs;
 
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
             const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // start = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
             // gid % sgSize == gid - (gid / sgSize) * sgSize
-            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
-            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
-            const std::size_t offset_max =
-                std::min(nelems, base + sgSize * elems_per_wi);
+            const std::size_t elems_per_sg = sgSize * elems_per_wi;
+            const std::size_t start =
+                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
+            const std::size_t end = std::min(nelems, start + elems_per_sg);
 
-            for (size_t offset = base; offset < offset_max; offset += sgSize) {
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
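Note (illustrative sketch, not part of the patch): the new start/end arithmetic in this hunk is algebraically equivalent to the removed base/offset_max computation; only the factoring of elems_per_sg changes. A minimal standalone check of that equivalence, using hypothetical values for sgSize and a small gid range chosen only for illustration:

    // equivalence_check.cpp -- standalone sketch, not dpctl code
    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main()
    {
        constexpr std::size_t sgSize = 16;       // assumed sub-group size
        constexpr std::uint8_t elems_per_wi = 8; // vec_sz * n_vecs with the defaults 4 * 2

        for (std::size_t gid = 0; gid < 1000; ++gid) {
            // old formulation: elems_per_sg excluded one sgSize stride
            const std::size_t old_base =
                (gid / sgSize) * (sgSize * (elems_per_wi - 1)) + gid;
            const std::size_t old_max = old_base + sgSize * elems_per_wi;

            // new formulation: elems_per_sg is the full per-sub-group extent
            const std::size_t elems_per_sg = sgSize * elems_per_wi;
            const std::size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = start + elems_per_sg;

            // both forms reduce to
            // (gid / sgSize) * sgSize * (elems_per_wi - 1) + gid
            assert(start == old_base && end == old_max);
        }
        return 0;
    }

The clamping with std::min(nelems, ...) in the kernel applies identically to both forms, so it is omitted from the sketch.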
@@ -132,8 +130,8 @@ class CopyAsCContigFunctor
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sg_load,
           typename KernelName>
 sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -145,7 +143,6 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 {
     static_assert(vec_sz > 0);
     static_assert(n_vecs > 0);
-    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
 
     constexpr std::size_t preferred_lws = 256;
 
@@ -187,8 +184,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 
 template <typename T,
           typename IndexT,
-          std::uint32_t vec_sz,
-          std::uint32_t n_vecs,
+          std::uint8_t vec_sz,
+          std::uint8_t n_vecs,
           bool enable_sgload>
 class as_contig_krn;
 
@@ -210,8 +207,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
     const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
 
-    constexpr std::uint32_t vec_sz = 4u;
-    constexpr std::uint32_t n_vecs = 2u;
+    constexpr std::uint8_t vec_sz = 4u;
+    constexpr std::uint8_t n_vecs = 2u;
 
     using dpctl::tensor::kernels::alignment_utils::
         disabled_sg_loadstore_wrapper_krn;
@@ -256,8 +253,8 @@ template <typename fnT, typename T> struct AsCContigFactory
 
 template <typename T,
           typename IndexerT,
-          std::uint32_t tile_size,
-          std::uint32_t n_lines>
+          std::uint16_t tile_size,
+          std::uint16_t n_lines>
 class as_contig_batch_of_square_matrices_krn;
 
 namespace detail
@@ -283,14 +280,14 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
     const T *src_tp = reinterpret_cast<const T *>(src_p);
     T *dst_tp = reinterpret_cast<T *>(dst_p);
 
-    constexpr std::uint32_t private_tile_size = 4;
-    constexpr std::uint32_t n_lines = 2;
-    constexpr std::uint32_t block_size =
+    constexpr std::uint16_t private_tile_size = 4;
+    constexpr std::uint16_t n_lines = 2;
+    constexpr std::uint16_t block_size =
         n_lines * private_tile_size * private_tile_size;
 
-    constexpr std::uint32_t lws0 = block_size;
-    constexpr std::uint32_t lws1 = n_lines;
-    constexpr std::uint32_t nelems_per_wi = (block_size / lws1);
+    constexpr std::uint16_t lws0 = block_size;
+    constexpr std::uint16_t lws1 = n_lines;
+    constexpr std::uint16_t nelems_per_wi = (block_size / lws1);
 
     static_assert(nelems_per_wi * lws1 == block_size);
     static_assert(nelems_per_wi == private_tile_size * private_tile_size);
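Note (worked example, not part of the patch): with the constants above, the launch geometry evaluates to

    block_size    = n_lines * private_tile_size * private_tile_size = 2 * 4 * 4 = 32
    lws0          = block_size = 32
    lws1          = n_lines = 2
    nelems_per_wi = block_size / lws1 = 16

so both static_asserts hold (16 * 2 == 32 and 16 == 4 * 4), and every value fits comfortably in std::uint16_t.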
@@ -377,40 +374,41 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             std::array<T, nelems_per_wi> private_block_01 = {T(0)};
             std::array<T, nelems_per_wi> private_block_10 = {T(0)};
 
-            // 0 <= lid_lin < lws0 * lws1 == (block_size * block_size /
-            // nelems_per_wi) == (block_size/private_tile_size)**2
-            constexpr std::uint32_t n_private_tiles_per_axis =
+            // 0 <= lid_lin < lws0 * lws1 ==
+            // (block_size * block_size / nelems_per_wi) ==
+            // (block_size/private_tile_size)**2
+            constexpr std::uint16_t n_private_tiles_per_axis =
                 block_size / private_tile_size;
-            const std::uint32_t local_tile_id0 =
+            const std::uint16_t local_tile_id0 =
                 lid_lin / n_private_tiles_per_axis;
-            const std::uint32_t local_tile_id1 =
+            const std::uint16_t local_tile_id1 =
                 lid_lin - local_tile_id0 * n_private_tiles_per_axis;
 
             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;
 
-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i1 * private_tile_size + pr_i0;
-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;
 
                         // read (local_tile_id0, local_tile_id1)
-                        const std::uint32_t local_01_offset =
+                        const std::uint16_t local_01_offset =
                             (t0_offset + t1_offset * block_size) + rel_offset;
                         private_block_01[pr_offset] =
                             local_block[local_01_offset];
 
                         // read (local_tile_id1, local_tile_id0)
-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         private_block_10[pr_offset] =
                             local_block[local_10_offset];
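Note (worked example, not part of the patch): with block_size = 32 and private_tile_size = 4, n_private_tiles_per_axis = 8 and 0 <= lid_lin < lws0 * lws1 = 64, so the division/remainder pair above maps, for instance, lid_lin = 19 to (local_tile_id0, local_tile_id1) = (2, 3). With these constants every local offset computed in this hunk stays well below 2^16, so the switch to std::uint16_t loses nothing.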
@@ -422,20 +420,20 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                 sycl::memory_scope::work_group);
 
             if (local_tile_id0 <= local_tile_id1) {
-                for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size;
+                for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size;
                      ++pr_i0)
                 {
-                    for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size;
+                    for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size;
                          ++pr_i1)
                     {
-                        const std::uint32_t t0_offset =
+                        const std::uint16_t t0_offset =
                             local_tile_id0 * private_tile_size;
-                        const std::uint32_t t1_offset =
+                        const std::uint16_t t1_offset =
                             local_tile_id1 * private_tile_size;
-                        const std::uint32_t pr_offset =
+                        const std::uint16_t pr_offset =
                             pr_i0 * private_tile_size + pr_i1;
 
-                        const std::uint32_t rel_offset =
+                        const std::uint16_t rel_offset =
                             pr_i0 + pr_i1 * block_size;
 
                         // write back permuted private blocks
@@ -444,7 +442,7 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                         local_block[local_01_offset] =
                             private_block_10[pr_offset];
 
-                        const std::uint32_t local_10_offset =
+                        const std::uint16_t local_10_offset =
                             (t1_offset + t0_offset * block_size) + rel_offset;
                         local_block[local_10_offset] =
                             private_block_01[pr_offset];
@@ -461,8 +459,8 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             const std::size_t dst_tile_start1 = src_tile_start1;
 
             if (local_dim0 == block_size && local_dim1 == block_size) {
-                const std::uint32_t dst_i0 = src_i1;
-                const std::uint32_t dst_i1 = src_i0;
+                const std::uint16_t dst_i0 = src_i1;
+                const std::uint16_t dst_i1 = src_i0;
 
                 const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                 const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);
@@ -471,11 +469,11 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
                     dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
                 const std::size_t pr_step_dst = lws1 * dst_stride;
 
-                const std::uint32_t _local_offset0 =
+                const std::uint16_t _local_offset0 =
                     dst_i0 * block_size + dst_i1;
-                const std::uint32_t _pr_step_local = lws1 * block_size;
+                const std::uint16_t _pr_step_local = lws1 * block_size;
 
-                for (std::uint32_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
+                for (std::uint16_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) {
                     if ((dst_gid1 < n) && ((dst_gid0 + pr_id * lws1) < n)) {
                         dst_tp[dst_offset0 + pr_step_dst * pr_id] =
                             local_block[_local_offset0 +
@@ -485,24 +483,24 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl(
             }
             else {
                 // map local_linear_id into (local_dim0, local_dim1)
-                for (std::uint32_t el_id = lid_lin;
+                for (std::uint16_t el_id = lid_lin;
                      el_id < local_dim0 * local_dim1; el_id += lws0 * lws1)
                 {
 
                     // 0 <= local_i0 < local_dim0
-                    const std::uint32_t loc_i0 = el_id / local_dim1;
+                    const std::uint16_t loc_i0 = el_id / local_dim1;
                     // 0 <= local_i1 < local_dim1
-                    const std::uint32_t loc_i1 = el_id - loc_i0 * local_dim1;
+                    const std::uint16_t loc_i1 = el_id - loc_i0 * local_dim1;
 
-                    const std::uint32_t dst_i0 = loc_i0;
-                    const std::uint32_t dst_i1 = loc_i1;
+                    const std::uint16_t dst_i0 = loc_i0;
+                    const std::uint16_t dst_i1 = loc_i1;
 
                     const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0);
                     const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1);
 
                     const std::size_t dst_offset =
                         dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1;
-                    const std::uint32_t local_offset =
+                    const std::uint16_t local_offset =
                         loc_i0 * block_size + loc_i1;
 
                     if ((dst_gid1 < n) && (dst_gid0 < n)) {