@@ -2560,12 +2560,16 @@ typedef struct {
2560
2560
uint8_t qs[QK4_NL/2 ];
2561
2561
} block_iq4_nl;
2562
2562
2563
// Storage type for iq4_xs blocks, selected at compile time by QK_K.
// When QK_K == 64 no separate layout is declared: block_iq4_xs is just a
// preprocessor alias for block_iq4_nl. For any other QK_K the struct below
// is used.
+ #if QK_K == 64
2564
+ #define block_iq4_xs block_iq4_nl
2565
+ #else
2563
2566
typedef struct {
2564
2567
half d; // NOTE(review): presumably the per-block scale - confirm against dequantize_iq4_xs
2565
2568
uint16_t scales_h; // NOTE(review): presumably the high bits of the sub-block scales - confirm
2566
2569
uint8_t scales_l[QK_K/64 ]; // QK_K/64 bytes; NOTE(review): presumably the low bits of the sub-block scales - confirm
2567
2570
uint8_t qs[QK_K/2 ]; // QK_K/2 bytes; NOTE(review): presumably two 4-bit quants per byte - confirm
2568
2571
} block_iq4_xs;
2572
+ #endif
2569
2573
2570
2574
// ====================================== dot products =========================
2571
2575
@@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4346
4350
threadgroup_barrier (mem_flags::mem_threadgroup);
4347
4351
}
4348
4352
4349
- #if QK_K == 256
4350
4353
const int ix = tiisg;
4351
4354
4352
4355
device const float * y4 = y + 32 * ix;
@@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4387
4390
4388
4391
y4 += 32 * 32 ;
4389
4392
}
4390
- #else
4391
- (void ) x;
4392
- (void ) y;
4393
- (void ) yl;
4394
- (void ) nb32;
4395
- #endif
4396
4393
4397
4394
for (int row = 0 ; row < N_DST; ++row) {
4398
4395
all_sum = simd_sum (sumf[row]);
@@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4482
4479
threadgroup_barrier (mem_flags::mem_threadgroup);
4483
4480
}
4484
4481
4485
- #if QK_K == 256
4486
4482
const int ix = tiisg;
4487
4483
4488
4484
device const float * y4 = y + 32 * ix;
@@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4533
4529
4534
4530
y4 += 32 * 32 ;
4535
4531
}
4536
- #else
4537
- (void ) x;
4538
- (void ) y;
4539
- (void ) yl;
4540
- (void ) nb32;
4541
- #endif
4542
4532
4543
4533
for (int row = 0 ; row < N_DST; ++row) {
4544
4534
all_sum = simd_sum (sumf[row]);
@@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4628
4618
threadgroup_barrier (mem_flags::mem_threadgroup);
4629
4619
}
4630
4620
4631
- #if QK_K == 256
4632
4621
const int ix = tiisg;
4633
4622
4634
4623
device const float * y4 = y + 32 * ix;
@@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4672
4661
4673
4662
y4 += 32 * 32 ;
4674
4663
}
4675
- #else
4676
- (void ) x;
4677
- (void ) y;
4678
- (void ) yl;
4679
- (void ) nb32;
4680
- #endif
4681
4664
4682
4665
for (int row = 0 ; row < N_DST; ++row) {
4683
4666
all_sum = simd_sum (sumf[row]);
@@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
5016
4999
5017
5000
const int nb32 = nb * (QK_K / 32 );
5018
5001
5019
- #if QK_K == 256
5020
5002
const int ix = tiisg/2 ;
5021
5003
const int il = tiisg%2 ;
5022
5004
@@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
5055
5037
5056
5038
y4 += 16 * 32 ;
5057
5039
}
5058
- #else
5059
- (void ) x;
5060
- (void ) y;
5061
- (void ) yl;
5062
- (void ) nb32;
5063
- #endif
5064
5040
5065
5041
for (int row = 0 ; row < N_DST; ++row) {
5066
5042
all_sum = simd_sum (sumf[row]);
@@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5167
5143
}
5168
5144
}
5169
5145
5146
+ #if QK_K != 64
5170
5147
void kernel_mul_mv_iq4_xs_f32_impl (
5171
5148
device const void * src0,
5172
5149
device const float * src1,
@@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5260
5237
}
5261
5238
}
5262
5239
}
5240
+ #endif
5263
5241
5264
5242
[[host_name(" kernel_mul_mv_iq1_s_f32" )]]
5265
5243
kernel void kernel_mul_mv_iq1_s_f32 (
@@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5344
5322
uint tiisg[[thread_index_in_simdgroup]],
5345
5323
uint sgitg[[simdgroup_index_in_threadgroup]]) {
5346
5324
5325
+ #if QK_K == 64
5326
+ kernel_mul_mv_iq4_nl_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5327
+ #else
5347
5328
kernel_mul_mv_iq4_xs_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5329
+ #endif
5348
5330
}
5349
5331
5350
5332
// ============================= templates and their specializations =============================
@@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
5770
5752
5771
5753
template <typename type4x4>
5772
5754
void dequantize_iq4_xs (device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
5755
+ #if QK_K == 64
5756
+ dequantize_iq4_nl (xb, il, reg);
5757
+ #else
5773
5758
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5774
5759
const int ib32 = il/2 ;
5775
5760
il = il%2 ;
@@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
5786
5771
reg[i][2 ] = d * kvalues_iq4nl_f[q8[2 ]];
5787
5772
reg[i][3 ] = d * kvalues_iq4nl_f[q8[3 ]];
5788
5773
}
5774
+ #endif
5789
5775
}
5790
5776
5791
5777
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
@@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
6334
6320
template [[host_name(" kernel_get_rows_iq2_s" )]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
6335
6321
template [[host_name(" kernel_get_rows_iq1_s" )]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6336
6322
template [[host_name(" kernel_get_rows_iq4_nl" )]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2 , dequantize_iq4_nl>;
6323
// Explicit instantiation of the iq4_xs get_rows kernel.
// With QK_K == 64 the second template argument is 2, the same value used by
// the kernel_get_rows_iq4_nl instantiation; for any other QK_K it is QK_NL,
// as for the iq2_s / iq1_s instantiations.
+ #if QK_K == 64
6324
+ template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2 , dequantize_iq4_xs>;
6325
+ #else
6337
6326
template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6327
+ #endif
6338
6328
6339
6329
//
6340
6330
// matrix-matrix multiplication
@@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
6378
6368
template [[host_name(" kernel_mul_mm_iq2_s_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
6379
6369
template [[host_name(" kernel_mul_mm_iq1_s_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6380
6370
template [[host_name(" kernel_mul_mm_iq4_nl_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2 , dequantize_iq4_nl>;
6371
// Explicit instantiation of the iq4_xs matrix-matrix multiplication kernel.
// With QK_K == 64 the second template argument is 2 (same as the
// kernel_mul_mm_iq4_nl_f32 instantiation); otherwise it is QK_NL.
// The block type is spelled block_iq4_xs in both branches so that it matches
// the parameter type in the signature of dequantize_iq4_xs and the sibling
// get_rows / mul_mm_id instantiations. For QK_K == 64 block_iq4_xs is a
// preprocessor alias of block_iq4_nl, so the instantiated type is unchanged.
+ #if QK_K == 64
6372
+ template [[host_name(" kernel_mul_mm_iq4_xs_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2 , dequantize_iq4_xs>;
6373
+ #else
6381
6374
template [[host_name(" kernel_mul_mm_iq4_xs_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6375
+ #endif
6382
6376
6383
6377
//
6384
6378
// indirect matrix-matrix multiplication
@@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
6434
6428
template [[host_name(" kernel_mul_mm_id_iq2_s_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
6435
6429
template [[host_name(" kernel_mul_mm_id_iq1_s_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6436
6430
template [[host_name(" kernel_mul_mm_id_iq4_nl_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2 , dequantize_iq4_nl>;
6431
// Explicit instantiation of the iq4_xs indirect matrix-matrix multiplication
// kernel. With QK_K == 64 the second template argument is 2, the same value
// used by the kernel_mul_mm_id_iq4_nl_f32 instantiation; for any other QK_K
// it is QK_NL, as for the iq2_s / iq1_s instantiations.
+ #if QK_K == 64
6432
+ template [[host_name(" kernel_mul_mm_id_iq4_xs_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2 , dequantize_iq4_xs>;
6433
+ #else
6437
6434
template [[host_name(" kernel_mul_mm_id_iq4_xs_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6435
+ #endif
6438
6436
6439
6437
//
6440
6438
// matrix-vector multiplication
@@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
7707
7705
7708
7706
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7709
7707
7708
+ #if QK_K == 64
7709
+ kernel_mul_mv_iq4_nl_f32_impl (
7710
+ #else
7710
7711
kernel_mul_mv_iq4_xs_f32_impl (
7712
+ #endif
7711
7713
src0[id],
7712
7714
(device const float *) (src1 + bid*nb11),
7713
7715
dst + bid*ne0,
0 commit comments