@@ -3832,13 +3832,16 @@ static __global__ void mul_mat_vec_q(
3832
3832
}
3833
3833
}
3834
3834
3835
- template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
3836
- static __global__ void dequantize_mul_mat_vec (const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
3835
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, bool need_check>
3836
+ static __global__ void dequantize_mul_mat_vec (
3837
+ const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst,
3838
+ const int ncols, const int nrows_x, const int nrows_y) {
3839
+
3837
3840
// qk = quantized weights per x block
3838
3841
// qr = number of quantized weights per data value in x block
3839
3842
const int row = blockIdx .y *blockDim .y + threadIdx .y ;
3840
3843
3841
- if (row >= nrows ) {
3844
+ if (row >= nrows_x ) {
3842
3845
return ;
3843
3846
}
3844
3847
@@ -3871,16 +3874,22 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
3871
3874
dfloat2 v;
3872
3875
dequantize_kernel (vx, ib, iqs + j/qr, v);
3873
3876
3877
+ const int iy0 = iybs + iqs + j/qr + 0 ;
3878
+ const int iy1 = iybs + iqs + j/qr + y_offset;
3879
+
3874
3880
// matrix multiplication
3875
3881
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
3876
3882
#ifdef GGML_CUDA_F16
3877
- tmp += __hmul2 (v, {
3878
- y[iybs + iqs + j/qr + 0 ],
3879
- y[iybs + iqs + j/qr + y_offset]
3880
- });
3883
+ const half yi0 = need_check && iy0 >= nrows_y ? __float2half ( 0 . 0f ) : y[iy0];
3884
+ const half yi1 = need_check && iy1 >= nrows_y ? __float2half ( 0 . 0f ) : y[iy1];
3885
+
3886
+ tmp += __hmul2 (v, {yi0, yi1 });
3881
3887
#else
3882
- tmp += v.x * y[iybs + iqs + j/qr + 0 ];
3883
- tmp += v.y * y[iybs + iqs + j/qr + y_offset];
3888
+ const float yi0 = need_check && iy0 >= nrows_y ? 0 .0f : y[iy0];
3889
+ const float yi1 = need_check && iy1 >= nrows_y ? 0 .0f : y[iy1];
3890
+
3891
+ tmp += v.x * yi0;
3892
+ tmp += v.y * yi1;
3884
3893
#endif // GGML_CUDA_F16
3885
3894
}
3886
3895
}
@@ -4380,91 +4389,120 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
4380
4389
#endif
4381
4390
}
4382
4391
4383
- static void dequantize_mul_mat_vec_q4_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4392
+ static void dequantize_mul_mat_vec_q4_0_cuda (
4393
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4394
+
4384
4395
GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4385
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4396
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4386
4397
const dim3 block_nums (1 , block_num_y, 1 );
4387
4398
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4388
- dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
4389
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows );
4399
+ dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0, false >
4400
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y );
4390
4401
}
4391
4402
4392
- static void dequantize_mul_mat_vec_q4_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4403
+ static void dequantize_mul_mat_vec_q4_1_cuda (
4404
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4405
+
4393
4406
GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4394
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4407
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4395
4408
const dim3 block_nums (1 , block_num_y, 1 );
4396
4409
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4397
- dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
4398
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows );
4410
+ dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1, false >
4411
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y );
4399
4412
}
4400
4413
4401
- static void dequantize_mul_mat_vec_q5_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4414
+ static void dequantize_mul_mat_vec_q5_0_cuda (
4415
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4416
+
4402
4417
GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4403
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4418
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4404
4419
const dim3 block_nums (1 , block_num_y, 1 );
4405
4420
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4406
- dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
4407
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows );
4421
+ dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0, false >
4422
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y );
4408
4423
}
4409
4424
4410
- static void dequantize_mul_mat_vec_q5_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4425
+ static void dequantize_mul_mat_vec_q5_1_cuda (
4426
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4427
+
4411
4428
GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4412
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4429
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4413
4430
const dim3 block_nums (1 , block_num_y, 1 );
4414
4431
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4415
- dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
4416
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows );
4432
+ dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1, false >
4433
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y );
4417
4434
}
4418
4435
4419
- static void dequantize_mul_mat_vec_q8_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4420
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4421
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4436
+ static void dequantize_mul_mat_vec_q8_0_cuda (
4437
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4438
+
4439
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4422
4440
const dim3 block_nums (1 , block_num_y, 1 );
4423
4441
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4424
- dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
4425
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
4442
+ if (ncols == nrows_y && ncols % GGML_CUDA_DMMV_X == 0 ) {
4443
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0, false >
4444
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y);
4445
+ } else {
4446
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0, true >
4447
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y);
4448
+ }
4426
4449
}
4427
4450
4428
- static void dequantize_mul_mat_vec_q2_K_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4451
+ static void dequantize_mul_mat_vec_q2_K_cuda (
4452
+ const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4453
+
4429
4454
GGML_ASSERT (ncols % QK_K == 0 );
4430
4455
const int ny = 2 ; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
4431
4456
const int block_num_y = (nrows + ny - 1 ) / ny;
4432
4457
const dim3 block_nums (1 , block_num_y, 1 );
4433
4458
const dim3 block_dims (32 , ny, 1 );
4434
4459
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
4460
+ (void ) nrows_y;
4435
4461
}
4436
4462
4437
- static void dequantize_mul_mat_vec_q3_K_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4463
+ static void dequantize_mul_mat_vec_q3_K_cuda (
4464
+ const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4465
+
4438
4466
GGML_ASSERT (ncols % QK_K == 0 );
4439
4467
const int ny = 2 / K_QUANTS_PER_ITERATION;
4440
4468
const int block_num_y = (nrows + ny - 1 ) / ny;
4441
4469
const dim3 block_nums (1 , block_num_y, 1 );
4442
4470
const dim3 block_dims (32 , ny, 1 );
4443
4471
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
4472
+ (void ) nrows_y;
4444
4473
}
4445
4474
4446
- static void dequantize_mul_mat_vec_q4_K_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4475
+ static void dequantize_mul_mat_vec_q4_K_cuda (
4476
+ const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4477
+
4447
4478
GGML_ASSERT (ncols % QK_K == 0 );
4448
4479
const int ny = 2 / K_QUANTS_PER_ITERATION;
4449
4480
const int block_num_y = (nrows + ny - 1 ) / ny;
4450
4481
const dim3 block_nums (1 , block_num_y, 1 );
4451
4482
const dim3 block_dims (32 , ny, 1 );
4452
4483
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
4484
+ (void ) nrows_y;
4453
4485
}
4454
4486
4455
- static void dequantize_mul_mat_vec_q5_K_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4487
+ static void dequantize_mul_mat_vec_q5_K_cuda (
4488
+ const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4489
+
4456
4490
GGML_ASSERT (ncols % QK_K == 0 );
4457
4491
const dim3 block_dims (32 , 1 , 1 );
4458
4492
dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0 , stream>>> (vx, y, dst, ncols);
4493
+ (void ) nrows_y;
4459
4494
}
4460
4495
4461
- static void dequantize_mul_mat_vec_q6_K_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4496
+ static void dequantize_mul_mat_vec_q6_K_cuda (
4497
+ const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4498
+
4462
4499
GGML_ASSERT (ncols % QK_K == 0 );
4463
4500
const int ny = 2 / K_QUANTS_PER_ITERATION;
4464
4501
const int block_num_y = (nrows + ny - 1 ) / ny;
4465
4502
const dim3 block_nums (1 , block_num_y, 1 );
4466
4503
const dim3 block_dims (32 , ny, 1 );
4467
4504
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
4505
+ (void ) nrows_y;
4468
4506
}
4469
4507
4470
4508
static void mul_mat_vec_q4_0_q8_1_cuda (
@@ -4592,13 +4630,15 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
4592
4630
dequantize_block<1 , 1 , convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4593
4631
}
4594
4632
4595
- static void convert_mul_mat_vec_f16_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4633
+ static void convert_mul_mat_vec_f16_cuda (
4634
+ const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4635
+
4596
4636
GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
4597
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4637
+ const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
4598
4638
const dim3 block_nums (1 , block_num_y, 1 );
4599
4639
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
4600
- dequantize_mul_mat_vec<1 , 1 , convert_f16>
4601
- <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows );
4640
+ dequantize_mul_mat_vec<1 , 1 , convert_f16, false >
4641
+ <<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows_x, nrows_y );
4602
4642
}
4603
4643
4604
4644
static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
@@ -5804,37 +5844,37 @@ inline void ggml_cuda_op_mul_mat_vec(
5804
5844
5805
5845
switch (src0->type ) {
5806
5846
case GGML_TYPE_Q4_0:
5807
- dequantize_mul_mat_vec_q4_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5847
+ dequantize_mul_mat_vec_q4_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5808
5848
break ;
5809
5849
case GGML_TYPE_Q4_1:
5810
- dequantize_mul_mat_vec_q4_1_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5850
+ dequantize_mul_mat_vec_q4_1_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5811
5851
break ;
5812
5852
case GGML_TYPE_Q5_0:
5813
- dequantize_mul_mat_vec_q5_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5853
+ dequantize_mul_mat_vec_q5_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5814
5854
break ;
5815
5855
case GGML_TYPE_Q5_1:
5816
- dequantize_mul_mat_vec_q5_1_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5856
+ dequantize_mul_mat_vec_q5_1_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5817
5857
break ;
5818
5858
case GGML_TYPE_Q8_0:
5819
- dequantize_mul_mat_vec_q8_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5859
+ dequantize_mul_mat_vec_q8_0_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5820
5860
break ;
5821
5861
case GGML_TYPE_Q2_K:
5822
- dequantize_mul_mat_vec_q2_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5862
+ dequantize_mul_mat_vec_q2_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5823
5863
break ;
5824
5864
case GGML_TYPE_Q3_K:
5825
- dequantize_mul_mat_vec_q3_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5865
+ dequantize_mul_mat_vec_q3_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5826
5866
break ;
5827
5867
case GGML_TYPE_Q4_K:
5828
- dequantize_mul_mat_vec_q4_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5868
+ dequantize_mul_mat_vec_q4_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5829
5869
break ;
5830
5870
case GGML_TYPE_Q5_K:
5831
- dequantize_mul_mat_vec_q5_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5871
+ dequantize_mul_mat_vec_q5_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5832
5872
break ;
5833
5873
case GGML_TYPE_Q6_K:
5834
- dequantize_mul_mat_vec_q6_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5874
+ dequantize_mul_mat_vec_q6_K_cuda (src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5835
5875
break ;
5836
5876
case GGML_TYPE_F16:
5837
- convert_mul_mat_vec_f16_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5877
+ convert_mul_mat_vec_f16_cuda (src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
5838
5878
break ;
5839
5879
default :
5840
5880
GGML_ASSERT (false );
@@ -6550,8 +6590,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
6550
6590
}
6551
6591
6552
6592
// no quantized non-contiguous support for lower CC kernels implemented
6553
- // const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
6554
- const bool nc_okay = false ;
6593
+ const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
6555
6594
6556
6595
if (all_on_device && nc_okay && ggml_is_permuted (src0) && ggml_is_permuted (src1) && src1->ne [1 ] == 1 ) {
6557
6596
ggml_cuda_mul_mat_vec_p021 (src0, src1, dst);
0 commit comments