@@ -1068,7 +1068,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
1068
1068
constexpr size_t m_groups = 1 ;
1069
1069
size_t delta_k (4 );
1070
1070
size_t n_wi (64 );
1071
- size_t delta_n (32 );
1071
+ size_t delta_n (16 );
1072
1072
1073
1073
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1074
1074
local_mem_size, reserved_slm_size,
@@ -1105,7 +1105,7 @@ sycl::event gemm_impl(sycl::queue &exec_q,
1105
1105
constexpr size_t m_groups = 2 ;
1106
1106
size_t delta_k (4 );
1107
1107
size_t n_wi (64 );
1108
- size_t delta_n (32 );
1108
+ size_t delta_n (16 );
1109
1109
1110
1110
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1111
1111
local_mem_size, reserved_slm_size,
@@ -1236,7 +1236,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
1236
1236
constexpr size_t m_groups = 1 ;
1237
1237
size_t delta_k (4 );
1238
1238
size_t n_wi (64 );
1239
- size_t delta_n (32 );
1239
+ size_t delta_n (16 );
1240
1240
1241
1241
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1242
1242
local_mem_size, reserved_slm_size,
@@ -1273,7 +1273,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q,
1273
1273
constexpr size_t m_groups = 2 ;
1274
1274
size_t delta_k (4 );
1275
1275
size_t n_wi (64 );
1276
- size_t delta_n (32 );
1276
+ size_t delta_n (16 );
1277
1277
1278
1278
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
1279
1279
local_mem_size, reserved_slm_size,
@@ -1968,7 +1968,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
1968
1968
// temp memory if only one group is needed
1969
1969
size_t delta_k (4 );
1970
1970
size_t n_wi (64 );
1971
- size_t delta_n (32 );
1971
+ size_t delta_n (16 );
1972
1972
1973
1973
using dpctl::tensor::type_utils::is_complex;
1974
1974
if constexpr (!is_complex<resTy>::value) {
@@ -3402,7 +3402,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
3402
3402
// temp memory if only one group is needed
3403
3403
size_t delta_k (4 );
3404
3404
size_t n_wi (64 );
3405
- size_t delta_n (32 );
3405
+ size_t delta_n (16 );
3406
3406
3407
3407
using dpctl::tensor::type_utils::is_complex;
3408
3408
if constexpr (!is_complex<resTy>::value) {
@@ -5472,8 +5472,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
5472
5472
if (m == 1 ) {
5473
5473
constexpr int m_groups = 1 ;
5474
5474
size_t delta_k (4 );
5475
- size_t n_wi (64 );
5476
- size_t delta_n (32 );
5475
+ size_t n_wi (32 );
5476
+ size_t delta_n (16 );
5477
5477
5478
5478
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5479
5479
local_mem_size, reserved_slm_size,
@@ -5514,8 +5514,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q,
5514
5514
else if (k > n && k > m) {
5515
5515
constexpr size_t m_groups = 2 ;
5516
5516
size_t delta_k (4 );
5517
- size_t n_wi (64 );
5518
- size_t delta_n (32 );
5517
+ size_t n_wi (32 );
5518
+ size_t delta_n (16 );
5519
5519
5520
5520
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5521
5521
local_mem_size, reserved_slm_size,
@@ -5677,8 +5677,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
5677
5677
if (m == 1 ) {
5678
5678
constexpr int m_groups = 1 ;
5679
5679
size_t delta_k (4 );
5680
- size_t n_wi (64 );
5681
- size_t delta_n (32 );
5680
+ size_t n_wi (32 );
5681
+ size_t delta_n (16 );
5682
5682
5683
5683
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5684
5684
local_mem_size, reserved_slm_size,
@@ -5719,8 +5719,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
5719
5719
else if (k > n && k > m) {
5720
5720
constexpr size_t m_groups = 2 ;
5721
5721
size_t delta_k (4 );
5722
- size_t n_wi (64 );
5723
- size_t delta_n (32 );
5722
+ size_t n_wi (32 );
5723
+ size_t delta_n (16 );
5724
5724
5725
5725
gemm_detail::scale_gemm_k_parameters<resTy, m_groups>(
5726
5726
local_mem_size, reserved_slm_size,
@@ -6499,7 +6499,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q,
6499
6499
6500
6500
if ((k > n && k > m) || m == 1 ) {
6501
6501
size_t delta_k (4 );
6502
- size_t n_wi (64 );
6502
+ size_t n_wi (32 );
6503
6503
size_t delta_n (4 );
6504
6504
6505
6505
using dpctl::tensor::type_utils::is_complex;
@@ -8205,7 +8205,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
8205
8205
8206
8206
if ((k > n && k > m) || m == 1 ) {
8207
8207
size_t delta_k (4 );
8208
- size_t n_wi (64 );
8208
+ size_t n_wi (32 );
8209
8209
size_t delta_n (4 );
8210
8210
8211
8211
using dpctl::tensor::type_utils::is_complex;
0 commit comments