Skip to content

Commit 1eaadb6

Browse files
committed
Make generic k-threaded kernels handle arbitrary m_groups
Also increases the delta_n hyper-parameter for k-threaded kernels from 16 to 32 to improve performance
1 parent 03c36eb commit 1eaadb6

File tree

1 file changed

+52
-38
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+52
-38
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 52 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -979,13 +979,16 @@ class GemmFunctorThreadK
979979
for (size_t q = 0; q < n_wi * delta_k; q += delta_k) {
980980
size_t sq = s + q;
981981
size_t sqmj = sq * m + j;
982-
local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
983-
(sq < k && j < m)
984-
? static_cast<resT>(rhs[rhs_indexer(sqmj)])
985-
: identity_,
986-
(sq < k && j + 1 < m)
987-
? static_cast<resT>(rhs[rhs_indexer(sqmj + 1)])
988-
: identity_);
982+
sycl::vec<resT, m_groups> local_B_vec;
983+
#pragma unroll
984+
for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) {
985+
local_B_vec[vec_idx] =
986+
(sq < k && j + vec_idx < m)
987+
? static_cast<resT>(
988+
rhs[rhs_indexer(sqmj + vec_idx)])
989+
: identity_;
990+
}
991+
local_B_block[local_s + q] = local_B_vec;
989992
}
990993
}
991994

@@ -1241,7 +1244,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
12411244
constexpr size_t m_groups = 1;
12421245
size_t delta_k(4);
12431246
size_t n_wi(64);
1244-
size_t delta_n(16);
1247+
size_t delta_n(32);
12451248

12461249
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
12471250
local_mem_size, reserved_slm_size, delta_k,
@@ -1277,7 +1280,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
12771280
constexpr size_t m_groups = 2;
12781281
size_t delta_k(4);
12791282
size_t n_wi(64);
1280-
size_t delta_n(16);
1283+
size_t delta_n(32);
12811284

12821285
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
12831286
local_mem_size, reserved_slm_size, delta_k,
@@ -1411,7 +1414,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
14111414
constexpr size_t m_groups = 1;
14121415
size_t delta_k(4);
14131416
size_t n_wi(64);
1414-
size_t delta_n(16);
1417+
size_t delta_n(32);
14151418

14161419
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
14171420
local_mem_size, reserved_slm_size, delta_k,
@@ -1447,7 +1450,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
14471450
constexpr size_t m_groups = 2;
14481451
size_t delta_k(4);
14491452
size_t n_wi(64);
1450-
size_t delta_n(16);
1453+
size_t delta_n(32);
14511454

14521455
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
14531456
local_mem_size, reserved_slm_size, delta_k,
@@ -1922,13 +1925,16 @@ class GemmNoAtomicFunctorThreadK
19221925
for (size_t q = 0; q < n_wi * delta_k; q += delta_k) {
19231926
size_t sq = s + q;
19241927
size_t sqmj = sq * m + j;
1925-
local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
1926-
(sq < k && j < m)
1927-
? static_cast<resT>(rhs[rhs_indexer(sqmj)])
1928-
: identity_,
1929-
(sq < k && j + 1 < m)
1930-
? static_cast<resT>(rhs[rhs_indexer(sqmj + 1)])
1931-
: identity_);
1928+
sycl::vec<resT, m_groups> local_B_vec;
1929+
#pragma unroll
1930+
for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) {
1931+
local_B_vec[vec_idx] =
1932+
(sq < k && j + vec_idx < m)
1933+
? static_cast<resT>(
1934+
rhs[rhs_indexer(sqmj + vec_idx)])
1935+
: identity_;
1936+
}
1937+
local_B_block[local_s + q] = local_B_vec;
19321938
}
19331939
}
19341940

@@ -2130,7 +2136,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
21302136
{
21312137
size_t delta_k(4);
21322138
size_t n_wi(64);
2133-
size_t delta_n(16);
2139+
size_t delta_n(32);
21342140

21352141
const sycl::device &dev = exec_q.get_device();
21362142
const size_t local_mem_size =
@@ -2862,7 +2868,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
28622868
{
28632869
size_t delta_k(4);
28642870
size_t n_wi(64);
2865-
size_t delta_n(16);
2871+
size_t delta_n(32);
28662872

28672873
const sycl::device &dev = exec_q.get_device();
28682874
const size_t local_mem_size =
@@ -3986,14 +3992,16 @@ class GemmBatchFunctorThreadK
39863992
for (size_t q = 0; q < n_wi * delta_k; q += delta_k) {
39873993
size_t sq = s + q;
39883994
size_t sqmj = sq * m + j;
3989-
local_B_block[local_s + q] = sycl::vec<resT, m_groups>(
3990-
(sq < k && j < m)
3991-
? static_cast<resT>(rhs[rhs_offset + rhs_indexer(sqmj)])
3992-
: identity_,
3993-
(sq < k && j + 1 < m)
3994-
? static_cast<resT>(
3995-
rhs[rhs_offset + rhs_indexer(sqmj + 1)])
3996-
: identity_);
3995+
sycl::vec<resT, m_groups> local_B_vec;
3996+
#pragma unroll
3997+
for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) {
3998+
local_B_vec[vec_idx] =
3999+
(sq < k && j + vec_idx < m)
4000+
? static_cast<resT>(
4001+
rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)])
4002+
: identity_;
4003+
}
4004+
local_B_block[local_s + q] = local_B_vec;
39974005
}
39984006
}
39994007

@@ -4310,7 +4318,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
43104318
constexpr size_t m_groups = 1;
43114319
size_t delta_k(4);
43124320
size_t n_wi(64);
4313-
size_t delta_n(16);
4321+
size_t delta_n(32);
43144322

43154323
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
43164324
local_mem_size, reserved_slm_size, delta_k,
@@ -4351,7 +4359,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
43514359
constexpr size_t m_groups = 2;
43524360
size_t delta_k(4);
43534361
size_t n_wi(64);
4354-
size_t delta_n(16);
4362+
size_t delta_n(32);
43554363

43564364
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
43574365
local_mem_size, reserved_slm_size, delta_k,
@@ -4516,7 +4524,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
45164524
constexpr size_t m_groups = 1;
45174525
size_t delta_k(4);
45184526
size_t n_wi(64);
4519-
size_t delta_n(16);
4527+
size_t delta_n(32);
45204528

45214529
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
45224530
local_mem_size, reserved_slm_size, delta_k,
@@ -4557,7 +4565,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
45574565
constexpr size_t m_groups = 2;
45584566
size_t delta_k(4);
45594567
size_t n_wi(64);
4560-
size_t delta_n(16);
4568+
size_t delta_n(32);
45614569

45624570
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
45634571
local_mem_size, reserved_slm_size, delta_k,
@@ -5096,10 +5104,16 @@ class GemmBatchNoAtomicFunctorThreadK
50965104
for (size_t q = 0; q < n_wi * delta_k; q += delta_k) {
50975105
size_t sq = s + q;
50985106
size_t sqmj = sq * m + j;
5099-
local_B_block[local_s + q] =
5100-
(sq < k && j < m)
5101-
? static_cast<resT>(rhs[rhs_offset + rhs_indexer(sqmj)])
5102-
: identity_;
5107+
sycl::vec<resT, m_groups> local_B_vec;
5108+
#pragma unroll
5109+
for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) {
5110+
local_B_vec[vec_idx] =
5111+
(sq < k && j + vec_idx < m)
5112+
? static_cast<resT>(
5113+
rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)])
5114+
: identity_;
5115+
}
5116+
local_B_block[local_s + q] = local_B_vec;
51035117
}
51045118
}
51055119

@@ -5331,7 +5345,7 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
53315345
{
53325346
size_t delta_k(4);
53335347
size_t n_wi(64);
5334-
size_t delta_n(16);
5348+
size_t delta_n(32);
53355349

53365350
const sycl::device &dev = exec_q.get_device();
53375351
const size_t local_mem_size =
@@ -6184,7 +6198,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
61846198
{
61856199
size_t delta_k(4);
61866200
size_t n_wi(64);
6187-
size_t delta_n(16);
6201+
size_t delta_n(32);
61886202

61896203
const sycl::device &dev = exec_q.get_device();
61906204
const size_t local_mem_size =

0 commit comments

Comments
 (0)