Skip to content

Commit 3d8d51c

Browse files
committed
metal: fix performance degradation from gqa
Integer arithmetic is slow on the GPU, and 64-bit divides are extremely slow. In the context of GQA, we introduced a 64-bit divide that cannot be optimized out by the compiler, which results in a decrease of ~8% in inference performance. This commit fixes the issue by calculating part of the offset with a 32-bit divide. Naturally, this limits the size of a single matrix to ~4GB; however, this limit should be sufficient for the near future.
1 parent 5f6de2a commit 3d8d51c

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

ggml-metal.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ void ggml_metal_graph_compute(
712712

713713
GGML_ASSERT(ne00 == ne10);
714714
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
715+
uint gqa = ne12/ne02;
715716
GGML_ASSERT(ne03 == ne13);
716717

717718
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
@@ -743,6 +744,7 @@ void ggml_metal_graph_compute(
743744
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
744745
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
745746
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
747+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
746748
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
747749
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
748750
}
@@ -845,6 +847,7 @@ void ggml_metal_graph_compute(
845847
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
846848
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
847849
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
850+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
848851

849852
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
850853
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

ggml-metal.metal

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
343343
// N_DST, so this is another explicit assumption of the implementation.
344344
template<typename block_q_type, int nr, int nsg, int nw>
345345
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
346-
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0,
346+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
347347
uint3 tgpig, uint tiisg, uint sgitg) {
348348
const int nb = ne00/QK4_0;
349349
const int r0 = tgpig.x;
350350
const int r1 = tgpig.y;
351351
const int im = tgpig.z;
352352
const int first_row = (r0 * nsg + sgitg) * nr;
353-
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0);
353+
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
354+
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
354355
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
355356
float yl[16]; // src1 vector cache
356357
float sumf[nr]={0.f};
@@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32(
397398
constant int64_t & ne10[[buffer(9)]],
398399
constant int64_t & ne12[[buffer(11)]],
399400
constant int64_t & ne0[[buffer(15)]],
401+
constant uint & gqa[[buffer(17)]],
400402
uint3 tgpig[[threadgroup_position_in_grid]],
401403
uint tiisg[[thread_index_in_simdgroup]],
402404
uint sgitg[[simdgroup_index_in_threadgroup]]) {
403-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
405+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
404406
}
405407

406408
kernel void kernel_mul_mat_q4_1_f32(
@@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32(
413415
constant int64_t & ne10[[buffer(9)]],
414416
constant int64_t & ne12[[buffer(11)]],
415417
constant int64_t & ne0[[buffer(15)]],
418+
constant uint & gqa[[buffer(17)]],
416419
uint3 tgpig[[threadgroup_position_in_grid]],
417420
uint tiisg[[thread_index_in_simdgroup]],
418421
uint sgitg[[simdgroup_index_in_threadgroup]]) {
419-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
422+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
420423
}
421424

422425
kernel void kernel_mul_mat_f16_f32(
@@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32(
797800
constant int64_t & ne10[[buffer(9)]],
798801
constant int64_t & ne12[[buffer(11)]],
799802
constant int64_t & ne0[[buffer(15)]],
803+
constant uint & gqa[[buffer(17)]],
800804
uint3 tgpig[[threadgroup_position_in_grid]],
801805
uint tiisg[[thread_index_in_simdgroup]],
802806
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32(
808812

809813
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
810814
const int ib_row = first_row * nb;
811-
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
815+
const uint offset0 = r2/gqa*(ne02/QK_K);
816+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
812817
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
813818
float yl[32];
814819
float sumf[N_DST]={0.f}, all_sum;
@@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32(
938943
constant int64_t & ne10[[buffer(9)]],
939944
constant int64_t & ne12[[buffer(11)]],
940945
constant int64_t & ne0[[buffer(15)]],
946+
constant uint & gqa[[buffer(17)]],
941947
uint3 tgpig[[threadgroup_position_in_grid]],
942948
uint tiisg[[thread_index_in_simdgroup]],
943949
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -949,8 +955,8 @@ kernel void kernel_mul_mat_q3_K_f32(
949955
const int64_t r2 = tgpig.x;
950956

951957
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
952-
953-
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
958+
const uint offset0 = r2/gqa*(ne02/QK_K);
959+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
954960
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
955961

956962
float yl[16];
@@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32(
10541060
constant int64_t & ne10[[buffer(9)]],
10551061
constant int64_t & ne12[[buffer(11)]],
10561062
constant int64_t & ne0[[buffer(15)]],
1063+
constant uint & gqa[[buffer(17)]],
10571064
uint3 tgpig[[threadgroup_position_in_grid]],
10581065
uint tiisg[[thread_index_in_simdgroup]],
10591066
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1065,8 +1072,8 @@ kernel void kernel_mul_mat_q3_K_f32(
10651072
const int64_t r2 = tgpig.x;
10661073

10671074
const int row = 2 * r0 + sgitg;
1068-
1069-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K);
1075+
const uint offset0 = r2/gqa*(ne02/QK_K);
1076+
// src0 block pointer: per-row offset (row*nb) plus the uint GQA offset `offset0` declared just above.
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
10701077
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
10711078
const int ix = tiisg/4;
10721079
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
@@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32(
11231130
constant int64_t & ne10[[buffer(9)]],
11241131
constant int64_t & ne12[[buffer(11)]],
11251132
constant int64_t & ne0[[buffer(15)]],
1133+
constant uint & gqa[[buffer(17)]],
11261134
uint3 tgpig[[threadgroup_position_in_grid]],
11271135
uint tiisg[[thread_index_in_simdgroup]],
11281136
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32(
11421150
const int r2 = tgpig.z;
11431151
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
11441152
const int ib_row = first_row * nb;
1145-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
1153+
const uint offset0 = r2/gqa*(ne02/QK_K);
1154+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
11461155
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
11471156
float yl[16];
11481157
float yh[16];
@@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32(
12251234
constant int64_t & ne10[[buffer(9)]],
12261235
constant int64_t & ne12[[buffer(11)]],
12271236
constant int64_t & ne0[[buffer(15)]],
1237+
constant uint & gqa[[buffer(17)]],
12281238
uint3 tgpig[[threadgroup_position_in_grid]],
12291239
uint tiisg[[thread_index_in_simdgroup]],
12301240
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32(
12381248
const int r2 = tgpig.z;
12391249
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
12401250
const int ib_row = first_row * nb;
1241-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
1251+
const uint offset0 = r2/gqa*(ne02/QK_K);
1252+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
12421253
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
12431254
float yl[8];
12441255
float yh[8];
@@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32(
13111322
constant int64_t & ne10[[buffer(9)]],
13121323
constant int64_t & ne12[[buffer(11)]],
13131324
constant int64_t & ne0[[buffer(15)]],
1325+
constant uint & gqa[[buffer(17)]],
13141326
uint3 tgpig[[threadgroup_position_in_grid]],
13151327
uint tiisg[[thread_index_in_simdgroup]],
13161328
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32(
13221334
const int r2 = tgpig.z;
13231335

13241336
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1325-
1326-
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
1337+
const uint offset0 = r2/gqa*(ne02/QK_K);
1338+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
13271339
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
13281340

13291341
float sumf[2]={0.f};
@@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32(
14741486
constant int64_t & ne10[[buffer(9)]],
14751487
constant int64_t & ne12[[buffer(11)]],
14761488
constant int64_t & ne0[[buffer(15)]],
1489+
constant uint & gqa[[buffer(17)]],
14771490
uint3 tgpig[[threadgroup_position_in_grid]],
14781491
uint tiisg[[thread_index_in_simdgroup]],
14791492
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32(
14901503
const int r2 = tgpig.z;
14911504

14921505
const int row = 2 * r0 + sgitg;
1493-
1494-
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K);
1506+
const uint offset0 = r2/gqa*(ne02/QK_K);
1507+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
14951508
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
14961509

14971510
float sumf = 0;
@@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
17921805
constant int64_t & ne12,
17931806
constant int64_t & ne0,
17941807
constant int64_t & ne1,
1808+
constant uint & gqa,
17951809
threadgroup uchar * shared_memory [[threadgroup(0)]],
17961810
uint3 tgpig[[threadgroup_position_in_grid]],
17971811
uint tiitg[[thread_index_in_threadgroup]],
@@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
18181832
}
18191833

18201834
short il = (tiitg % THREAD_PER_ROW);
1821-
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl;
1835+
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
1836+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
18221837
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
18231838
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
18241839

@@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
19091924

19101925
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
19111926
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
1912-
constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint, uint);
1927+
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
19131928

19141929
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
19151930
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;

0 commit comments

Comments
 (0)