@@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
343
343
// N_DST, so this is another explicit assumption of the implementation.
344
344
template <typename block_q_type, int nr, int nsg, int nw>
345
345
void mul_vec_q_n_f32 (device const void * src0, device const float * src1, device float * dst,
346
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0,
346
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
347
347
uint3 tgpig, uint tiisg, uint sgitg) {
348
348
const int nb = ne00/QK4_0;
349
349
const int r0 = tgpig.x ;
350
350
const int r1 = tgpig.y ;
351
351
const int im = tgpig.z ;
352
352
const int first_row = (r0 * nsg + sgitg) * nr;
353
- device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0);
353
+ const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
354
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
354
355
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
355
356
float yl[16 ]; // src1 vector cache
356
357
float sumf[nr]={0 .f };
@@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32(
397
398
constant int64_t & ne10[[buffer(9 )]],
398
399
constant int64_t & ne12[[buffer(11 )]],
399
400
constant int64_t & ne0[[buffer(15 )]],
401
+ constant uint & gqa[[buffer(17 )]],
400
402
uint3 tgpig[[threadgroup_position_in_grid]],
401
403
uint tiisg[[thread_index_in_simdgroup]],
402
404
uint sgitg[[simdgroup_index_in_threadgroup]]) {
403
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
405
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa, tgpig,tiisg,sgitg);
404
406
}
405
407
406
408
kernel void kernel_mul_mat_q4_1_f32 (
@@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32(
413
415
constant int64_t & ne10[[buffer(9 )]],
414
416
constant int64_t & ne12[[buffer(11 )]],
415
417
constant int64_t & ne0[[buffer(15 )]],
418
+ constant uint & gqa[[buffer(17 )]],
416
419
uint3 tgpig[[threadgroup_position_in_grid]],
417
420
uint tiisg[[thread_index_in_simdgroup]],
418
421
uint sgitg[[simdgroup_index_in_threadgroup]]) {
419
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
422
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa, tgpig,tiisg,sgitg);
420
423
}
421
424
422
425
kernel void kernel_mul_mat_f16_f32 (
@@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32(
797
800
constant int64_t & ne10[[buffer(9 )]],
798
801
constant int64_t & ne12[[buffer(11 )]],
799
802
constant int64_t & ne0[[buffer(15 )]],
803
+ constant uint & gqa[[buffer(17 )]],
800
804
uint3 tgpig[[threadgroup_position_in_grid]],
801
805
uint tiisg[[thread_index_in_simdgroup]],
802
806
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32(
808
812
809
813
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
810
814
const int ib_row = first_row * nb;
811
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
815
+ const uint offset0 = r2/gqa*(ne02/QK_K);
816
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
812
817
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
813
818
float yl[32 ];
814
819
float sumf[N_DST]={0 .f }, all_sum;
@@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32(
938
943
constant int64_t & ne10[[buffer(9 )]],
939
944
constant int64_t & ne12[[buffer(11 )]],
940
945
constant int64_t & ne0[[buffer(15 )]],
946
+ constant uint & gqa[[buffer(17 )]],
941
947
uint3 tgpig[[threadgroup_position_in_grid]],
942
948
uint tiisg[[thread_index_in_simdgroup]],
943
949
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -949,8 +955,8 @@ kernel void kernel_mul_mat_q3_K_f32(
949
955
const int64_t r2 = tgpig.x ;
950
956
951
957
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
952
-
953
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K) ;
958
+ const uint offset0 = r2/gqa*(ne02/QK_K);
959
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0 ;
954
960
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
955
961
956
962
float yl[16 ];
@@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1054
1060
constant int64_t & ne10[[buffer(9 )]],
1055
1061
constant int64_t & ne12[[buffer(11 )]],
1056
1062
constant int64_t & ne0[[buffer(15 )]],
1063
+ constant uint & gqa[[buffer(17 )]],
1057
1064
uint3 tgpig[[threadgroup_position_in_grid]],
1058
1065
uint tiisg[[thread_index_in_simdgroup]],
1059
1066
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1065,8 +1072,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1065
1072
const int64_t r2 = tgpig.x ;
1066
1073
1067
1074
const int row = 2 * r0 + sgitg;
1068
-
1069
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K) ;
1075
+ const uint offset0 = r2/gqa*(ne02/QK_K);
1076
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offest0 ;
1070
1077
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
1071
1078
const int ix = tiisg/4 ;
1072
1079
const int il = 4 * (tiisg%4 );// 0, 4, 8, 12
@@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1123
1130
constant int64_t & ne10[[buffer(9 )]],
1124
1131
constant int64_t & ne12[[buffer(11 )]],
1125
1132
constant int64_t & ne0[[buffer(15 )]],
1133
+ constant uint & gqa[[buffer(17 )]],
1126
1134
uint3 tgpig[[threadgroup_position_in_grid]],
1127
1135
uint tiisg[[thread_index_in_simdgroup]],
1128
1136
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32(
1142
1150
const int r2 = tgpig.z ;
1143
1151
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1144
1152
const int ib_row = first_row * nb;
1145
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
1153
+ const uint offset0 = r2/gqa*(ne02/QK_K);
1154
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1146
1155
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
1147
1156
float yl[16 ];
1148
1157
float yh[16 ];
@@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1225
1234
constant int64_t & ne10[[buffer(9 )]],
1226
1235
constant int64_t & ne12[[buffer(11 )]],
1227
1236
constant int64_t & ne0[[buffer(15 )]],
1237
+ constant uint & gqa[[buffer(17 )]],
1228
1238
uint3 tgpig[[threadgroup_position_in_grid]],
1229
1239
uint tiisg[[thread_index_in_simdgroup]],
1230
1240
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32(
1238
1248
const int r2 = tgpig.z ;
1239
1249
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1240
1250
const int ib_row = first_row * nb;
1241
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
1251
+ const uint offset0 = r2/gqa*(ne02/QK_K);
1252
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1242
1253
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
1243
1254
float yl[8 ];
1244
1255
float yh[8 ];
@@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1311
1322
constant int64_t & ne10[[buffer(9 )]],
1312
1323
constant int64_t & ne12[[buffer(11 )]],
1313
1324
constant int64_t & ne0[[buffer(15 )]],
1325
+ constant uint & gqa[[buffer(17 )]],
1314
1326
uint3 tgpig[[threadgroup_position_in_grid]],
1315
1327
uint tiisg[[thread_index_in_simdgroup]],
1316
1328
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32(
1322
1334
const int r2 = tgpig.z ;
1323
1335
1324
1336
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
1325
-
1326
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K) ;
1337
+ const uint offset0 = r2/gqa*(ne02/QK_K);
1338
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0 ;
1327
1339
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
1328
1340
1329
1341
float sumf[2 ]={0 .f };
@@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32(
1474
1486
constant int64_t & ne10[[buffer(9 )]],
1475
1487
constant int64_t & ne12[[buffer(11 )]],
1476
1488
constant int64_t & ne0[[buffer(15 )]],
1489
+ constant uint & gqa[[buffer(17 )]],
1477
1490
uint3 tgpig[[threadgroup_position_in_grid]],
1478
1491
uint tiisg[[thread_index_in_simdgroup]],
1479
1492
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1490
1503
const int r2 = tgpig.z ;
1491
1504
1492
1505
const int row = 2 * r0 + sgitg;
1493
-
1494
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K) ;
1506
+ const uint offset0 = r2/gqa*(ne02/QK_K);
1507
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0 ;
1495
1508
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
1496
1509
1497
1510
float sumf = 0 ;
@@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
1792
1805
constant int64_t & ne12,
1793
1806
constant int64_t & ne0,
1794
1807
constant int64_t & ne1,
1808
+ constant uint & gqa,
1795
1809
threadgroup uchar * shared_memory [[threadgroup(0 )]],
1796
1810
uint3 tgpig[[threadgroup_position_in_grid]],
1797
1811
uint tiitg[[thread_index_in_threadgroup]],
@@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
1818
1832
}
1819
1833
1820
1834
short il = (tiitg % THREAD_PER_ROW);
1821
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl;
1835
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
1836
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
1822
1837
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1823
1838
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
1824
1839
@@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
1909
1924
1910
1925
typedef void (mat_mm_t )(device const uchar *, device const float *, device float *, constant int64_t &,\
1911
1926
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
1912
- constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint , uint );
1927
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint , uint );
1913
1928
1914
1929
template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
1915
1930
template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2 , dequantize_q4_0>;
0 commit comments