Skip to content

Commit e6e61e3

Browse files
committed
iq3_s: partial fix for QK_K = 64
1 parent 1d47de3 commit e6e61e3

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

ggml-metal.metal

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2526,12 +2526,17 @@ typedef struct {
25262526
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
25272527

25282528
// 3.4375 bpw
2529+
#if QK_K == 64
2530+
#define IQ3S_N_SCALE 2
2531+
#else
2532+
#define IQ3S_N_SCALE QK_K/64
2533+
#endif
25292534
typedef struct {
25302535
half d;
25312536
uint8_t qs[QK_K/4];
25322537
uint8_t qh[QK_K/32];
25332538
uint8_t signs[QK_K/8];
2534-
uint8_t scales[QK_K/64];
2539+
uint8_t scales[IQ3S_N_SCALE];
25352540
} block_iq3_s;
25362541

25372542
typedef struct {

ggml-quants.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10772,7 +10772,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
1077210772

1077310773
const int kMaxQ = 8;
1077410774

10775-
const int nbl = n/256;
10775+
const int nbl = n/QK_K;
1077610776

1077710777
ggml_fp16_t * dh;
1077810778
uint8_t * qs;
@@ -11018,7 +11018,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
1101811018

1101911019
const int kMaxQ = 8;
1102011020

11021-
const int nbl = n/256;
11021+
const int nbl = n/QK_K;
1102211022

1102311023
block_iq3_s * y = vy;
1102411024

@@ -11189,7 +11189,7 @@ size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, in
1118911189
uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
1119011190
char * qrow = (char *)dst;
1119111191
for (int row = 0; row < nrow; ++row) {
11192-
quantize_row_iq3_s_impl(32, src, qrow, n_per_row, quant_weights,
11192+
quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
1119311193
scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
1119411194
src += n_per_row;
1119511195
qrow += nblock*sizeof(block_iq3_s);

ggml-quants.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,20 @@ typedef struct {
191191
} block_iq3_xxs;
192192
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
193193

194-
// 3.3125 bpw
194+
// 3.4375 bpw
195+
#if QK_K == 64
196+
#define IQ3S_N_SCALE 2
197+
#else
198+
#define IQ3S_N_SCALE QK_K/64
199+
#endif
195200
typedef struct {
196201
ggml_fp16_t d;
197202
uint8_t qs[QK_K/4];
198203
uint8_t qh[QK_K/32];
199204
uint8_t signs[QK_K/8];
200-
uint8_t scales[QK_K/64];
205+
uint8_t scales[IQ3S_N_SCALE];
201206
} block_iq3_s;
202-
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
207+
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
203208

204209
typedef struct {
205210
ggml_fp16_t d;

0 commit comments

Comments
 (0)