@@ -544,7 +544,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
544
544
545
545
// Divisor constant used to derive QI3_XS below.
#define QR3_XS 8
546
546
// Derived from the super-block size QK_K and QR3_XS: QK_K / (4 * QR3_XS).
#define QI3_XS (QK_K / (4 *QR3_XS))
547
// IQ3S_BLOCK_SIZE selects, at compile time, how vec_dot_iq3_s_q8_1 applies scales:
//   == 32    : one integer accumulator per ib32 index, multiplied by a single 4-bit
//              scale taken from scales[ib32/2] (low nibble for even ib32, high nibble
//              for odd ib32).
//   otherwise: two integer accumulators (indexed by l/2 inside the 4-iteration loop),
//              each multiplied by its own 4-bit scale taken from the low / high nibble
//              of scales[ib32].
// In both paths the 4-bit scale s is applied as (1 + 2*s).
// NOTE(review): the name suggests this is the number of weights sharing one scale
// (32 vs 16) - confirm against the quantization code before relying on it.
- #define IQ3S_BLOCK_SIZE 16
547
+ #define IQ3S_BLOCK_SIZE 32
548
548
typedef struct {
549
549
half d;
550
550
uint8_t qs[QK_K/4 ];
@@ -5237,7 +5237,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
5237
5237
const int ib32 = iqs;
5238
5238
const uint8_t * qs = bq2->qs + 8 *ib32;
5239
5239
const int8_t * q8 = bq8_1[ib32].qs ;
5240
+ #if IQ3S_BLOCK_SIZE == 32
5240
5241
int sumi = 0 ;
5242
+ #else
5243
+ int sumi[2 ] = {0 , 0 };
5244
+ #endif
5241
5245
for (int l = 0 ; l < 4 ; ++l) {
5242
5246
#ifdef IQ3S_SLOW_MULT
5243
5247
aux32[0 ] = ((qs[2 *l+0 ] | ((bq2->qh [ib32] << (8 - 2 *l)) & 256 )) * IQ3S_MULTIPLIER) & 0x0f0f0f0f ;
@@ -5252,12 +5256,23 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
5252
5256
uint32_t signs1 = __vcmpeq4 (((bq2->signs [4 *ib32+l] >> 4 ) * 0x01010101 ) & 0x08040201 , 0x08040201 );
5253
5257
const int grid_l = __vsub4 (aux32[0 ] ^ signs0, signs0);
5254
5258
const int grid_h = __vsub4 (aux32[1 ] ^ signs1, signs1);
5259
+ #if IQ3S_BLOCK_SIZE == 32
5255
5260
sumi = __dp4a (grid_l, *((int *)q8+0 ), sumi);
5256
5261
sumi = __dp4a (grid_h, *((int *)q8+1 ), sumi);
5262
+ #else
5263
+ sumi[l/2 ] = __dp4a (grid_l, *((int *)q8+0 ), sumi[l/2 ]);
5264
+ sumi[l/2 ] = __dp4a (grid_h, *((int *)q8+1 ), sumi[l/2 ]);
5265
+ #endif
5257
5266
q8 += 8 ;
5258
5267
}
5268
+ #if IQ3S_BLOCK_SIZE == 32
5259
5269
const float d = (float )bq2->d * (1 + 2 *((bq2->scales [ib32/2 ] >> 4 *(ib32%2 )) & 0xf )) * __low2float (bq8_1[ib32].ds );
5260
5270
return d * sumi;
5271
+ #else
5272
+ int ls1 = 1 + 2 *(bq2->scales [ib32] & 0xf );
5273
+ int ls2 = 1 + 2 *(bq2->scales [ib32] >> 4 );
5274
+ return (float )bq2->d * __low2float (bq8_1[ib32].ds ) * (ls1 * sumi[0 ] + ls2 * sumi[1 ]);
5275
+ #endif
5261
5276
#else
5262
5277
assert (false );
5263
5278
return 0 .f ;
0 commit comments