Skip to content

Commit f2c2bd6

Browse files
committed
iq3_s_mult: also CUDA
1 parent e5e7256 commit f2c2bd6

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

ggml-cuda.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
544544

545545
#define QR3_XS 8
546546
#define QI3_XS (QK_K / (4*QR3_XS))
547-
#define IQ3S_BLOCK_SIZE 16
547+
#define IQ3S_BLOCK_SIZE 32
548548
typedef struct {
549549
half d;
550550
uint8_t qs[QK_K/4];
@@ -5237,7 +5237,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
52375237
const int ib32 = iqs;
52385238
const uint8_t * qs = bq2->qs + 8*ib32;
52395239
const int8_t * q8 = bq8_1[ib32].qs;
5240+
#if IQ3S_BLOCK_SIZE == 32
52405241
int sumi = 0;
5242+
#else
5243+
int sumi[2] = {0, 0};
5244+
#endif
52415245
for (int l = 0; l < 4; ++l) {
52425246
#ifdef IQ3S_SLOW_MULT
52435247
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
@@ -5252,12 +5256,23 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
52525256
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
52535257
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
52545258
const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
5259+
#if IQ3S_BLOCK_SIZE == 32
52555260
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
52565261
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
5262+
#else
5263+
sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]);
5264+
sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]);
5265+
#endif
52575266
q8 += 8;
52585267
}
5268+
#if IQ3S_BLOCK_SIZE == 32
52595269
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
52605270
return d * sumi;
5271+
#else
5272+
int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf);
5273+
int ls2 = 1 + 2*(bq2->scales[ib32] >> 4);
5274+
return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]);
5275+
#endif
52615276
#else
52625277
assert(false);
52635278
return 0.f;

ggml-quants.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10037,6 +10037,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
1003710037
UNUSED(by);
1003810038
UNUSED(bs);
1003910039

10040+
GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented");
10041+
1004010042
const block_iq3_s * restrict x = vx;
1004110043
const block_q8_K * restrict y = vy;
1004210044

0 commit comments

Comments
 (0)