@@ -979,13 +979,16 @@ class GemmFunctorThreadK
979
979
for (size_t q = 0 ; q < n_wi * delta_k; q += delta_k) {
980
980
size_t sq = s + q;
981
981
size_t sqmj = sq * m + j;
982
- local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
983
- (sq < k && j < m)
984
- ? static_cast <resT>(rhs[rhs_indexer (sqmj)])
985
- : identity_,
986
- (sq < k && j + 1 < m)
987
- ? static_cast <resT>(rhs[rhs_indexer (sqmj + 1 )])
988
- : identity_);
982
+ sycl::vec<resT, m_groups> local_B_vec;
983
+ #pragma unroll
984
+ for (size_t vec_idx = 0 ; vec_idx < m_groups; ++vec_idx) {
985
+ local_B_vec[vec_idx] =
986
+ (sq < k && j + vec_idx < m)
987
+ ? static_cast <resT>(
988
+ rhs[rhs_indexer (sqmj + vec_idx)])
989
+ : identity_;
990
+ }
991
+ local_B_block[local_s + q] = local_B_vec;
989
992
}
990
993
}
991
994
@@ -1241,7 +1244,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
1241
1244
constexpr size_t m_groups = 1 ;
1242
1245
size_t delta_k (4 );
1243
1246
size_t n_wi (64 );
1244
- size_t delta_n (16 );
1247
+ size_t delta_n (32 );
1245
1248
1246
1249
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1247
1250
local_mem_size, reserved_slm_size, delta_k,
@@ -1277,7 +1280,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
1277
1280
constexpr size_t m_groups = 2 ;
1278
1281
size_t delta_k (4 );
1279
1282
size_t n_wi (64 );
1280
- size_t delta_n (16 );
1283
+ size_t delta_n (32 );
1281
1284
1282
1285
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1283
1286
local_mem_size, reserved_slm_size, delta_k,
@@ -1411,7 +1414,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
1411
1414
constexpr size_t m_groups = 1 ;
1412
1415
size_t delta_k (4 );
1413
1416
size_t n_wi (64 );
1414
- size_t delta_n (16 );
1417
+ size_t delta_n (32 );
1415
1418
1416
1419
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1417
1420
local_mem_size, reserved_slm_size, delta_k,
@@ -1447,7 +1450,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
1447
1450
constexpr size_t m_groups = 2 ;
1448
1451
size_t delta_k (4 );
1449
1452
size_t n_wi (64 );
1450
- size_t delta_n (16 );
1453
+ size_t delta_n (32 );
1451
1454
1452
1455
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1453
1456
local_mem_size, reserved_slm_size, delta_k,
@@ -1922,13 +1925,16 @@ class GemmNoAtomicFunctorThreadK
1922
1925
for (size_t q = 0 ; q < n_wi * delta_k; q += delta_k) {
1923
1926
size_t sq = s + q;
1924
1927
size_t sqmj = sq * m + j;
1925
- local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
1926
- (sq < k && j < m)
1927
- ? static_cast <resT>(rhs[rhs_indexer (sqmj)])
1928
- : identity_,
1929
- (sq < k && j + 1 < m)
1930
- ? static_cast <resT>(rhs[rhs_indexer (sqmj + 1 )])
1931
- : identity_);
1928
+ sycl::vec<resT, m_groups> local_B_vec;
1929
+ #pragma unroll
1930
+ for (size_t vec_idx = 0 ; vec_idx < m_groups; ++vec_idx) {
1931
+ local_B_vec[vec_idx] =
1932
+ (sq < k && j + vec_idx < m)
1933
+ ? static_cast <resT>(
1934
+ rhs[rhs_indexer (sqmj + vec_idx)])
1935
+ : identity_;
1936
+ }
1937
+ local_B_block[local_s + q] = local_B_vec;
1932
1938
}
1933
1939
}
1934
1940
@@ -2130,7 +2136,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
2130
2136
{
2131
2137
size_t delta_k (4 );
2132
2138
size_t n_wi (64 );
2133
- size_t delta_n (16 );
2139
+ size_t delta_n (32 );
2134
2140
2135
2141
const sycl::device &dev = exec_q.get_device ();
2136
2142
const size_t local_mem_size =
@@ -2862,7 +2868,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
2862
2868
{
2863
2869
size_t delta_k (4 );
2864
2870
size_t n_wi (64 );
2865
- size_t delta_n (16 );
2871
+ size_t delta_n (32 );
2866
2872
2867
2873
const sycl::device &dev = exec_q.get_device ();
2868
2874
const size_t local_mem_size =
@@ -3986,14 +3992,16 @@ class GemmBatchFunctorThreadK
3986
3992
for (size_t q = 0 ; q < n_wi * delta_k; q += delta_k) {
3987
3993
size_t sq = s + q;
3988
3994
size_t sqmj = sq * m + j;
3989
- local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
3990
- (sq < k && j < m)
3991
- ? static_cast <resT>(rhs[rhs_offset + rhs_indexer (sqmj)])
3992
- : identity_,
3993
- (sq < k && j + 1 < m)
3994
- ? static_cast <resT>(
3995
- rhs[rhs_offset + rhs_indexer (sqmj + 1 )])
3996
- : identity_);
3995
+ sycl::vec<resT, m_groups> local_B_vec;
3996
+ #pragma unroll
3997
+ for (size_t vec_idx = 0 ; vec_idx < m_groups; ++vec_idx) {
3998
+ local_B_vec[vec_idx] =
3999
+ (sq < k && j + vec_idx < m)
4000
+ ? static_cast <resT>(
4001
+ rhs[rhs_offset + rhs_indexer (sqmj + vec_idx)])
4002
+ : identity_;
4003
+ }
4004
+ local_B_block[local_s + q] = local_B_vec;
3997
4005
}
3998
4006
}
3999
4007
@@ -4310,7 +4318,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
4310
4318
constexpr size_t m_groups = 1 ;
4311
4319
size_t delta_k (4 );
4312
4320
size_t n_wi (64 );
4313
- size_t delta_n (16 );
4321
+ size_t delta_n (32 );
4314
4322
4315
4323
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
4316
4324
local_mem_size, reserved_slm_size, delta_k,
@@ -4351,7 +4359,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
4351
4359
constexpr size_t m_groups = 2 ;
4352
4360
size_t delta_k (4 );
4353
4361
size_t n_wi (64 );
4354
- size_t delta_n (16 );
4362
+ size_t delta_n (32 );
4355
4363
4356
4364
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
4357
4365
local_mem_size, reserved_slm_size, delta_k,
@@ -4516,7 +4524,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
4516
4524
constexpr size_t m_groups = 1 ;
4517
4525
size_t delta_k (4 );
4518
4526
size_t n_wi (64 );
4519
- size_t delta_n (16 );
4527
+ size_t delta_n (32 );
4520
4528
4521
4529
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
4522
4530
local_mem_size, reserved_slm_size, delta_k,
@@ -4557,7 +4565,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
4557
4565
constexpr size_t m_groups = 2 ;
4558
4566
size_t delta_k (4 );
4559
4567
size_t n_wi (64 );
4560
- size_t delta_n (16 );
4568
+ size_t delta_n (32 );
4561
4569
4562
4570
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
4563
4571
local_mem_size, reserved_slm_size, delta_k,
@@ -5096,10 +5104,16 @@ class GemmBatchNoAtomicFunctorThreadK
5096
5104
for (size_t q = 0 ; q < n_wi * delta_k; q += delta_k) {
5097
5105
size_t sq = s + q;
5098
5106
size_t sqmj = sq * m + j;
5099
- local_B_block[local_s + q] =
5100
- (sq < k && j < m)
5101
- ? static_cast <resT>(rhs[rhs_offset + rhs_indexer (sqmj)])
5102
- : identity_;
5107
+ sycl::vec<resT, m_groups> local_B_vec;
5108
+ #pragma unroll
5109
+ for (size_t vec_idx = 0 ; vec_idx < m_groups; ++vec_idx) {
5110
+ local_B_vec[vec_idx] =
5111
+ (sq < k && j + vec_idx < m)
5112
+ ? static_cast <resT>(
5113
+ rhs[rhs_offset + rhs_indexer (sqmj + vec_idx)])
5114
+ : identity_;
5115
+ }
5116
+ local_B_block[local_s + q] = local_B_vec;
5103
5117
}
5104
5118
}
5105
5119
@@ -5331,7 +5345,7 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
5331
5345
{
5332
5346
size_t delta_k (4 );
5333
5347
size_t n_wi (64 );
5334
- size_t delta_n (16 );
5348
+ size_t delta_n (32 );
5335
5349
5336
5350
const sycl::device &dev = exec_q.get_device ();
5337
5351
const size_t local_mem_size =
@@ -6184,7 +6198,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
6184
6198
{
6185
6199
size_t delta_k (4 );
6186
6200
size_t n_wi (64 );
6187
- size_t delta_n (16 );
6201
+ size_t delta_n (32 );
6188
6202
6189
6203
const sycl::device &dev = exec_q.get_device ();
6190
6204
const size_t local_mem_size =
0 commit comments