Skip to content

Commit 144ac0f

Browse files
committed
Drastically reduced the n_wi and delta_n parameters used by the gemm kernels that thread over k.
Experimental change to see whether this stabilizes CI.
1 parent ad53472 commit 144ac0f

File tree

1 file changed

+22
-22
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+22
-22
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1067,8 +1067,8 @@ sycl::event gemm_impl(sycl::queue &exec_q,
10671067
if (m == 1) {
10681068
constexpr size_t m_groups = 1;
10691069
size_t delta_k(4);
1070-
size_t n_wi(64);
1071-
size_t delta_n(16);
1070+
size_t n_wi(4);
1071+
size_t delta_n(4);
10721072

10731073
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
10741074
local_mem_size, reserved_slm_size, delta_k,
@@ -1103,8 +1103,8 @@ sycl::event gemm_impl(sycl::queue &exec_q,
11031103
else if (k > n && k > m) {
11041104
constexpr size_t m_groups = 2;
11051105
size_t delta_k(4);
1106-
size_t n_wi(64);
1107-
size_t delta_n(16);
1106+
size_t n_wi(4);
1107+
size_t delta_n(4);
11081108

11091109
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
11101110
local_mem_size, reserved_slm_size, delta_k,
@@ -1233,8 +1233,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12331233
if (m == 1) {
12341234
constexpr size_t m_groups = 1;
12351235
size_t delta_k(4);
1236-
size_t n_wi(64);
1237-
size_t delta_n(16);
1236+
size_t n_wi(4);
1237+
size_t delta_n(4);
12381238

12391239
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
12401240
local_mem_size, reserved_slm_size, delta_k,
@@ -1269,8 +1269,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
12691269
else if (k > n && k > m) {
12701270
constexpr size_t m_groups = 2;
12711271
size_t delta_k(4);
1272-
size_t n_wi(64);
1273-
size_t delta_n(16);
1272+
size_t n_wi(4);
1273+
size_t delta_n(4);
12741274

12751275
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
12761276
local_mem_size, reserved_slm_size, delta_k,
@@ -1963,8 +1963,8 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
19631963
// items in a column, so no need for allocating
19641964
// temp memory if only one group is needed
19651965
size_t delta_k(4);
1966-
size_t n_wi(64);
1967-
size_t delta_n(16);
1966+
size_t n_wi(4);
1967+
size_t delta_n(4);
19681968

19691969
using dpctl::tensor::type_utils::is_complex;
19701970
if constexpr (!is_complex<resTy>::value) {
@@ -3394,8 +3394,8 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
33943394
// items in a column, so no need for allocating
33953395
// temp memory if only one group is needed
33963396
size_t delta_k(4);
3397-
size_t n_wi(64);
3398-
size_t delta_n(16);
3397+
size_t n_wi(4);
3398+
size_t delta_n(4);
33993399

34003400
using dpctl::tensor::type_utils::is_complex;
34013401
if constexpr (!is_complex<resTy>::value) {
@@ -5462,8 +5462,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
54625462
if (m == 1) {
54635463
constexpr int m_groups = 1;
54645464
size_t delta_k(4);
5465-
size_t n_wi(32);
5466-
size_t delta_n(16);
5465+
size_t n_wi(4);
5466+
size_t delta_n(4);
54675467

54685468
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
54695469
local_mem_size, reserved_slm_size, delta_k,
@@ -5503,8 +5503,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
55035503
else if (k > n && k > m) {
55045504
constexpr size_t m_groups = 2;
55055505
size_t delta_k(4);
5506-
size_t n_wi(32);
5507-
size_t delta_n(16);
5506+
size_t n_wi(4);
5507+
size_t delta_n(4);
55085508

55095509
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
55105510
local_mem_size, reserved_slm_size, delta_k,
@@ -5664,8 +5664,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
56645664
if (m == 1) {
56655665
constexpr int m_groups = 1;
56665666
size_t delta_k(4);
5667-
size_t n_wi(32);
5668-
size_t delta_n(16);
5667+
size_t n_wi(4);
5668+
size_t delta_n(4);
56695669

56705670
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
56715671
local_mem_size, reserved_slm_size, delta_k,
@@ -5705,8 +5705,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
57055705
else if (k > n && k > m) {
57065706
constexpr size_t m_groups = 2;
57075707
size_t delta_k(4);
5708-
size_t n_wi(32);
5709-
size_t delta_n(16);
5708+
size_t n_wi(4);
5709+
size_t delta_n(4);
57105710

57115711
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
57125712
local_mem_size, reserved_slm_size, delta_k,
@@ -6484,7 +6484,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q,
64846484

64856485
if ((k > n && k > m) || m == 1) {
64866486
size_t delta_k(4);
6487-
size_t n_wi(32);
6487+
size_t n_wi(4);
64886488
size_t delta_n(4);
64896489

64906490
using dpctl::tensor::type_utils::is_complex;
@@ -8187,7 +8187,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
81878187

81888188
if ((k > n && k > m) || m == 1) {
81898189
size_t delta_k(4);
8190-
size_t n_wi(32);
8190+
size_t n_wi(4);
81918191
size_t delta_n(4);
81928192

81938193
using dpctl::tensor::type_utils::is_complex;

0 commit comments

Comments
 (0)