Skip to content

Commit 7c4263d

Browse files
ikawrakow and Kawrakow authored
ggml : make i-quants work with super-blocks of 64 (CPU,Metal) (#5760)
* WIP: make i-quants work for QK_K = 64
* iq2_xs: attempt to fix AVX dot product for QK_K = 64
  Tests pass, but I get gibberish.
* QK_K = 64 tests pass on ARM_NEON and Metal
  Sadly, that does not mean it actually works.
* Make CUDA compile with QK_K = 64
  Tests don't pass, plus we get misaligned access
* Q2_K: fixed bug in imatrix quantization for QK_K = 64
* iq1_s: turn off SIMD implementation for QK_K = 64 (it does not work)
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent cb49e0f commit 7c4263d

File tree

5 files changed

+194
-59
lines changed

5 files changed

+194
-59
lines changed

ggml-cuda.cu

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,14 +544,19 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
544544

545545
#define QR3_XS 8
546546
#define QI3_XS (QK_K / (4*QR3_XS))
547+
#if QK_K == 64
548+
#define IQ3S_N_SCALE 2
549+
#else
550+
#define IQ3S_N_SCALE (QK_K/64) // one scale byte per 64 values of a super-block; parenthesized so the macro expands safely inside larger expressions
551+
#endif
547552
typedef struct {
548553
half d;
549554
uint8_t qs[QK_K/4];
550555
uint8_t qh[QK_K/32];
551556
uint8_t signs[QK_K/8];
552-
uint8_t scales[QK_K/64];
557+
uint8_t scales[IQ3S_N_SCALE];
553558
} block_iq3_s;
554-
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
559+
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
555560

556561
#define QR1_S 8
557562
#define QI1_S (QK_K / (4*QR1_S))
@@ -571,6 +576,11 @@ typedef struct {
571576
} block_iq4_nl;
572577
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
573578

579+
#if QK_K == 64
580+
#define block_iq4_xs block_iq4_nl
581+
#define QR4_XS QR4_NL
582+
#define QI4_XS QI4_NL
583+
#else
574584
// QR4_XS = 8 is very slightly faster than QR4_XS = 4
575585
#define QR4_XS 8
576586
#define QI4_XS (QK_K / (4*QR4_XS))
@@ -581,7 +591,7 @@ typedef struct {
581591
uint8_t qs[QK_K/2];
582592
} block_iq4_xs;
583593
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
584-
594+
#endif
585595

586596
#define WARP_SIZE 32
587597
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -2439,9 +2449,9 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
24392449

24402450
}
24412451

2452+
#if QK_K != 64
24422453
template<typename dst_t>
24432454
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
2444-
24452455
const int i = blockIdx.x;
24462456
const block_iq4_xs * x = (const block_iq4_xs *)vx;
24472457

@@ -2455,8 +2465,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
24552465
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
24562466
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
24572467
}
2458-
24592468
}
2469+
#endif
24602470

24612471
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
24622472

@@ -5382,8 +5392,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
53825392
return 0.f;
53835393
#endif
53845394
#else
5385-
assert(false);
5386-
return 0.f;
5395+
// QK_K == 64: block_iq4_xs is #defined to block_iq4_nl, so delegate to the iq4_nl dot product.
// The previous self-call recursed unconditionally. NOTE(review): confirm vec_dot_iq4_nl_q8_1 is declared above with this signature.
return vec_dot_iq4_nl_q8_1(vbq, bq8_1, iqs);
53875396
#endif
53885397
}
53895398

@@ -7444,7 +7453,11 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
74447453
// Host-side launcher for iq4_xs dequantization: reads quantized data from vx and
// writes dst_t values to y, enqueuing the kernel on `stream`.
// k is presumably the number of values to dequantize (TODO confirm against callers);
// it is rounded up to a whole number of QK_K-sized super-blocks.
template<typename dst_t>
74457454
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
74467455
const int nb = (k + QK_K - 1) / QK_K; // ceil(k / QK_K): one thread block per super-block
7456+
#if QK_K == 64
7457+
// With QK_K == 64, block_iq4_xs is #defined to block_iq4_nl and the iq4_xs kernel
// is compiled out, so the iq4_nl kernel is launched instead.
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
7458+
#else
74477459
// nb blocks of 32 threads, no dynamic shared memory.
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
7460+
#endif
74487461
}
74497462

74507463
template <typename src_t, typename dst_t>

ggml-metal.metal

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2560,12 +2560,16 @@ typedef struct {
25602560
uint8_t qs[QK4_NL/2];
25612561
} block_iq4_nl;
25622562

2563+
#if QK_K == 64
2564+
#define block_iq4_xs block_iq4_nl
2565+
#else
25632566
typedef struct {
25642567
half d;
25652568
uint16_t scales_h;
25662569
uint8_t scales_l[QK_K/64];
25672570
uint8_t qs[QK_K/2];
25682571
} block_iq4_xs;
2572+
#endif
25692573

25702574
//====================================== dot products =========================
25712575

@@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
43464350
threadgroup_barrier(mem_flags::mem_threadgroup);
43474351
}
43484352

4349-
#if QK_K == 256
43504353
const int ix = tiisg;
43514354

43524355
device const float * y4 = y + 32 * ix;
@@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
43874390

43884391
y4 += 32 * 32;
43894392
}
4390-
#else
4391-
(void) x;
4392-
(void) y;
4393-
(void) yl;
4394-
(void) nb32;
4395-
#endif
43964393

43974394
for (int row = 0; row < N_DST; ++row) {
43984395
all_sum = simd_sum(sumf[row]);
@@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
44824479
threadgroup_barrier(mem_flags::mem_threadgroup);
44834480
}
44844481

4485-
#if QK_K == 256
44864482
const int ix = tiisg;
44874483

44884484
device const float * y4 = y + 32 * ix;
@@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
45334529

45344530
y4 += 32 * 32;
45354531
}
4536-
#else
4537-
(void) x;
4538-
(void) y;
4539-
(void) yl;
4540-
(void) nb32;
4541-
#endif
45424532

45434533
for (int row = 0; row < N_DST; ++row) {
45444534
all_sum = simd_sum(sumf[row]);
@@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
46284618
threadgroup_barrier(mem_flags::mem_threadgroup);
46294619
}
46304620

4631-
#if QK_K == 256
46324621
const int ix = tiisg;
46334622

46344623
device const float * y4 = y + 32 * ix;
@@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
46724661

46734662
y4 += 32 * 32;
46744663
}
4675-
#else
4676-
(void) x;
4677-
(void) y;
4678-
(void) yl;
4679-
(void) nb32;
4680-
#endif
46814664

46824665
for (int row = 0; row < N_DST; ++row) {
46834666
all_sum = simd_sum(sumf[row]);
@@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
50164999

50175000
const int nb32 = nb * (QK_K / 32);
50185001

5019-
#if QK_K == 256
50205002
const int ix = tiisg/2;
50215003
const int il = tiisg%2;
50225004

@@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
50555037

50565038
y4 += 16 * 32;
50575039
}
5058-
#else
5059-
(void) x;
5060-
(void) y;
5061-
(void) yl;
5062-
(void) nb32;
5063-
#endif
50645040

50655041
for (int row = 0; row < N_DST; ++row) {
50665042
all_sum = simd_sum(sumf[row]);
@@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
51675143
}
51685144
}
51695145

5146+
#if QK_K != 64
51705147
void kernel_mul_mv_iq4_xs_f32_impl(
51715148
device const void * src0,
51725149
device const float * src1,
@@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
52605237
}
52615238
}
52625239
}
5240+
#endif
52635241

52645242
[[host_name("kernel_mul_mv_iq1_s_f32")]]
52655243
kernel void kernel_mul_mv_iq1_s_f32(
@@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
53445322
uint tiisg[[thread_index_in_simdgroup]],
53455323
uint sgitg[[simdgroup_index_in_threadgroup]]) {
53465324

5325+
#if QK_K == 64
5326+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5327+
#else
53475328
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5329+
#endif
53485330
}
53495331

53505332
//============================= templates and their specializations =============================
@@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
57705752

57715753
template <typename type4x4>
57725754
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
5755+
#if QK_K == 64
5756+
dequantize_iq4_nl(xb, il, reg);
5757+
#else
57735758
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
57745759
const int ib32 = il/2;
57755760
il = il%2;
@@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
57865771
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
57875772
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
57885773
}
5774+
#endif
57895775
}
57905776

57915777
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
63346320
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
63356321
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
63366322
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6323+
#if QK_K == 64
6324+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6325+
#else
63376326
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6327+
#endif
63386328

63396329
//
63406330
// matrix-matrix multiplication
@@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
63786368
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
63796369
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
63806370
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6371+
#if QK_K == 64
6372+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2, dequantize_iq4_xs>; // block_iq4_xs is #defined to block_iq4_nl when QK_K == 64; named consistently with the get_rows / mul_mm_id instantiations
6373+
#else
63816374
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6375+
#endif
63826376

63836377
//
63846378
// indirect matrix-matrix multiplication
@@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
64346428
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
64356429
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
64366430
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6431+
#if QK_K == 64
6432+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6433+
#else
64376434
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6435+
#endif
64386436

64396437
//
64406438
// matrix-vector multiplication
@@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
77077705

77087706
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
77097707

7708+
#if QK_K == 64
7709+
kernel_mul_mv_iq4_nl_f32_impl(
7710+
#else
77107711
kernel_mul_mv_iq4_xs_f32_impl(
7712+
#endif
77117713
src0[id],
77127714
(device const float *) (src1 + bid*nb11),
77137715
dst + bid*ne0,

0 commit comments

Comments
 (0)