Skip to content

Commit 5ed38b6

Browse files
authored
ggml : fix MUL_MAT_ID repack with Q8_K (#12544)
* ggml : fix MUL_MAT_ID repack with Q8_K

  ggml-ci

* ggml : improve repack templates

  ggml-ci
1 parent fd7855f commit 5ed38b6

File tree

1 file changed

+87
-91
lines changed

1 file changed

+87
-91
lines changed

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Lines changed: 87 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
250250

251251
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
252252

253-
static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
253+
static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
254254
assert(QK8_0 == 32);
255255
assert(k % QK8_0 == 0);
256256
const int nb = k / QK8_0;
@@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
344344
#endif
345345
}
346346

347-
static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
347+
static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
348348
assert(QK8_0 == 32);
349349
assert(k % QK8_0 == 0);
350350
const int nb = k / QK8_0;
@@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
559559
#endif
560560
}
561561

562-
static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
562+
static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563563
assert(QK_K == 256);
564564
assert(k % QK_K == 0);
565565
const int nb = k / QK_K;
@@ -811,7 +811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
811811
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
812812
for (int j = 0; j < QK_K * 4; j++) {
813813
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
814-
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
814+
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
815815
src_offset += (j % blck_size_interleave);
816816
int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
817817

@@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
823823
#endif
824824
}
825825

826-
static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
826+
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
827+
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
828+
829+
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
827830
assert(nrow == 4);
828831
UNUSED(nrow);
829-
if (blck_size_interleave == 4) {
830-
quantize_q8_0_4x4(x, vy, n_per_row);
831-
} else if (blck_size_interleave == 8) {
832-
quantize_q8_0_4x8(x, vy, n_per_row);
833-
} else {
834-
assert(false);
835-
}
832+
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
836833
}
837834

838-
static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
835+
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
839836
assert(nrow == 4);
840837
UNUSED(nrow);
841-
if (blck_size_interleave == 8) {
842-
quantize_q8_K_4x8(x, vy, n_per_row);
843-
} else {
844-
assert(false);
845-
}
838+
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
839+
}
840+
841+
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
842+
assert(nrow == 4);
843+
UNUSED(nrow);
844+
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
846845
}
847846

848847
static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -5276,52 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
52765275
//}
52775276

52785277
// gemv
5279-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5278+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
52805279
void gemv(int, float *, size_t, const void *, const void *, int, int);
52815280

5282-
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5281+
template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52835282
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
52845283
}
52855284

5286-
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5285+
template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52875286
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
52885287
}
52895288

5290-
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5289+
template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52915290
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
52925291
}
52935292

5294-
template <> void gemv<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5293+
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52955294
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
52965295
}
52975296

5298-
template <>
5299-
void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5297+
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53005298
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
53015299
}
53025300

53035301
// gemm
5304-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5302+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
53055303
void gemm(int, float *, size_t, const void *, const void *, int, int);
53065304

5307-
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5305+
template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53085306
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
53095307
}
53105308

5311-
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5309+
template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53125310
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
53135311
}
53145312

5315-
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5313+
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53165314
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
53175315
}
53185316

5319-
template <> void gemm<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5317+
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53205318
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
53215319
}
53225320

5323-
template <>
5324-
void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5321+
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53255322
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
53265323
}
53275324

@@ -5335,32 +5332,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53355332
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
53365333
// not really a GGML_TYPE_Q8_0 but same size.
53375334
switch (op->op) {
5338-
case GGML_OP_MUL_MAT:
5339-
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5340-
return true;
5341-
case GGML_OP_MUL_MAT_ID:
5342-
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5343-
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
5344-
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
5345-
return true;
5346-
default:
5347-
// GGML_ABORT("fatal error");
5348-
break;
5335+
case GGML_OP_MUL_MAT:
5336+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5337+
return true;
5338+
case GGML_OP_MUL_MAT_ID:
5339+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5340+
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
5341+
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
5342+
return true;
5343+
default:
5344+
// GGML_ABORT("fatal error");
5345+
break;
53495346
}
53505347
return false;
53515348
}
53525349

53535350
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
53545351
switch (op->op) {
5355-
case GGML_OP_MUL_MAT:
5356-
forward_mul_mat(params, op);
5357-
return true;
5358-
case GGML_OP_MUL_MAT_ID:
5359-
forward_mul_mat_id(params, op);
5360-
return true;
5361-
default:
5362-
// GGML_ABORT("fatal error");
5363-
break;
5352+
case GGML_OP_MUL_MAT:
5353+
forward_mul_mat(params, op);
5354+
return true;
5355+
case GGML_OP_MUL_MAT_ID:
5356+
forward_mul_mat_id(params, op);
5357+
return true;
5358+
default:
5359+
// GGML_ABORT("fatal error");
5360+
break;
53645361
}
53655362
return false;
53665363
}
@@ -5399,17 +5396,10 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53995396
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
54005397

54015398
int64_t i11_processed = 0;
5402-
if(PARAM_TYPE == GGML_TYPE_Q8_K) {
5403-
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5404-
quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5405-
INTER_SIZE);
5406-
}
5407-
} else {
5408-
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5409-
quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5410-
INTER_SIZE);
5411-
}
5399+
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5400+
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
54125401
}
5402+
54135403
i11_processed = ne11 - ne11 % 4;
54145404
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
54155405
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -5422,22 +5412,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54225412
int64_t src0_start = (ith * ne01) / nth;
54235413
int64_t src0_end = ((ith + 1) * ne01) / nth;
54245414
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
5425-
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
5415+
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
54265416
if (src0_start >= src0_end) {
54275417
return;
54285418
}
54295419

54305420
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
54315421
if (ne11 > 3) {
5432-
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
5433-
(const char *) src0->data + src0_start * nb01,
5434-
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
5422+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5423+
(float *) ((char *) dst->data) + src0_start, ne01,
5424+
(const char *) src0->data + src0_start * nb01,
5425+
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
54355426
}
54365427
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
5437-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5438-
(const char *) src0->data + src0_start * nb01,
5439-
(const char *) src1_wdata + (src1_col_stride * iter), 1,
5440-
src0_end - src0_start);
5428+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5429+
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5430+
(const char *) src0->data + src0_start * nb01,
5431+
(const char *) src1_wdata + (src1_col_stride * iter), 1,
5432+
src0_end - src0_start);
54415433
}
54425434
}
54435435

@@ -5452,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54525444
const int ith = params->ith;
54535445
const int nth = params->nth;
54545446

5455-
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
5447+
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
54565448

54575449
// we don't support permuted src0 or src1
54585450
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
@@ -5474,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54745466
const int n_ids = ids->ne[0]; // n_expert_used
54755467
const int n_as = ne02; // n_expert
54765468

5477-
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
5469+
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
54785470
const size_t nbw2 = nbw1*ne11;
54795471
const size_t nbw3 = nbw2*ne12;
54805472

@@ -5486,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54865478
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
54875479
n_as * ne12 * sizeof(mmid_row_mapping)));
54885480

5489-
auto wdata = (char *) params->wdata;
5490-
auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
5491-
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5481+
auto * wdata = (char *) params->wdata;
5482+
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
5483+
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5484+
54925485
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
54935486

5494-
// src1: float32 => block_q8_0
5487+
// src1: float32 => param type
54955488
for (int64_t i12 = 0; i12 < ne12; ++i12) {
54965489
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
54975490
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
@@ -5530,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
55305523
continue;
55315524
}
55325525

5533-
auto src0_cur = (const char *) src0->data + cur_a*nb02;
5526+
const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
55345527

55355528
//const int64_t nr0 = ne01; // src0 rows
55365529
const int64_t nr1 = cne1; // src1 rows
55375530

55385531
int64_t src0_cur_start = (ith * ne01) / nth;
55395532
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
5540-
src0_cur_start =
5541-
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5542-
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
55435533

5544-
if (src0_cur_start >= src0_cur_end) return;
5534+
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5535+
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
5536+
5537+
if (src0_cur_start >= src0_cur_end) {
5538+
return;
5539+
}
55455540

55465541
for (int ir1 = 0; ir1 < nr1; ir1++) {
55475542
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
5548-
const int id = row_mapping.i1; // selected expert index
55495543

5550-
const int64_t i11 = id % ne11;
5551-
const int64_t i12 = row_mapping.i2; // row index in src1
5544+
const int id = row_mapping.i1; // selected expert index
5545+
5546+
const int64_t i11 = id % ne11;
5547+
const int64_t i12 = row_mapping.i2; // row index in src1
55525548

5553-
const int64_t i1 = id; // selected expert index
5554-
const int64_t i2 = i12; // row
5549+
const int64_t i1 = id; // selected expert index
5550+
const int64_t i2 = i12; // row
55555551

5556-
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
5552+
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
55575553

5558-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
5559-
ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
5560-
ne01, src0_cur + src0_cur_start * nb01,
5554+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5555+
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
5556+
src0_cur + src0_cur_start * nb01,
55615557
src1_col, 1, src0_cur_end - src0_cur_start);
55625558
}
55635559
}
@@ -5578,7 +5574,7 @@ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
55785574
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
55795575

55805576
// instance for IQ4
5581-
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_IQ4_NL> iq4_nl_4x4_q8_0;
5577+
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
55825578

55835579
} // namespace ggml::cpu::aarch64
55845580

0 commit comments

Comments
 (0)