Skip to content

Commit e0d5c0b

Browse files
fixed dmmv
1 parent 1f8bbf1 commit e0d5c0b

File tree

1 file changed

+91
-52
lines changed

1 file changed

+91
-52
lines changed

ggml-cuda.cu

Lines changed: 91 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -3832,13 +3832,16 @@ static __global__ void mul_mat_vec_q(
38323832
}
38333833
}
38343834

3835-
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
3836-
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
3835+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, bool need_check>
3836+
static __global__ void dequantize_mul_mat_vec(
3837+
const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst,
3838+
const int ncols, const int nrows_x, const int nrows_y) {
3839+
38373840
// qk = quantized weights per x block
38383841
// qr = number of quantized weights per data value in x block
38393842
const int row = blockIdx.y*blockDim.y + threadIdx.y;
38403843

3841-
if (row >= nrows) {
3844+
if (row >= nrows_x) {
38423845
return;
38433846
}
38443847

@@ -3871,16 +3874,22 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
38713874
dfloat2 v;
38723875
dequantize_kernel(vx, ib, iqs + j/qr, v);
38733876

3877+
const int iy0 = iybs + iqs + j/qr + 0;
3878+
const int iy1 = iybs + iqs + j/qr + y_offset;
3879+
38743880
// matrix multiplication
38753881
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
38763882
#ifdef GGML_CUDA_F16
3877-
tmp += __hmul2(v, {
3878-
y[iybs + iqs + j/qr + 0],
3879-
y[iybs + iqs + j/qr + y_offset]
3880-
});
3883+
const half yi0 = need_check && iy0 >= nrows_y ? __float2half(0.0f) : y[iy0];
3884+
const half yi1 = need_check && iy1 >= nrows_y ? __float2half(0.0f) : y[iy1];
3885+
3886+
tmp += __hmul2(v, {yi0, yi1});
38813887
#else
3882-
tmp += v.x * y[iybs + iqs + j/qr + 0];
3883-
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
3888+
const float yi0 = need_check && iy0 >= nrows_y ? 0.0f : y[iy0];
3889+
const float yi1 = need_check && iy1 >= nrows_y ? 0.0f : y[iy1];
3890+
3891+
tmp += v.x * yi0;
3892+
tmp += v.y * yi1;
38843893
#endif // GGML_CUDA_F16
38853894
}
38863895
}
@@ -4380,91 +4389,120 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
43804389
#endif
43814390
}
43824391

4383-
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4392+
static void dequantize_mul_mat_vec_q4_0_cuda(
4393+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4394+
43844395
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4385-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4396+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
43864397
const dim3 block_nums(1, block_num_y, 1);
43874398
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4388-
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
4389-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4399+
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0, false>
4400+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
43904401
}
43914402

4392-
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4403+
static void dequantize_mul_mat_vec_q4_1_cuda(
4404+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4405+
43934406
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4394-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4407+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
43954408
const dim3 block_nums(1, block_num_y, 1);
43964409
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4397-
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
4398-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4410+
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1, false>
4411+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
43994412
}
44004413

4401-
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4414+
static void dequantize_mul_mat_vec_q5_0_cuda(
4415+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4416+
44024417
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4403-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4418+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
44044419
const dim3 block_nums(1, block_num_y, 1);
44054420
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4406-
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
4407-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4421+
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0, false>
4422+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
44084423
}
44094424

4410-
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4425+
static void dequantize_mul_mat_vec_q5_1_cuda(
4426+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4427+
44114428
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4412-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4429+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
44134430
const dim3 block_nums(1, block_num_y, 1);
44144431
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4415-
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
4416-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4432+
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1, false>
4433+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
44174434
}
44184435

4419-
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4420-
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4421-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4436+
static void dequantize_mul_mat_vec_q8_0_cuda(
4437+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4438+
4439+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
44224440
const dim3 block_nums(1, block_num_y, 1);
44234441
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4424-
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
4425-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4442+
if (ncols == nrows_y && ncols % GGML_CUDA_DMMV_X == 0) {
4443+
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0, false>
4444+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
4445+
} else {
4446+
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0, true>
4447+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
4448+
}
44264449
}
44274450

4428-
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4451+
static void dequantize_mul_mat_vec_q2_K_cuda(
4452+
const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4453+
44294454
GGML_ASSERT(ncols % QK_K == 0);
44304455
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
44314456
const int block_num_y = (nrows + ny - 1) / ny;
44324457
const dim3 block_nums(1, block_num_y, 1);
44334458
const dim3 block_dims(32, ny, 1);
44344459
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4460+
(void) nrows_y;
44354461
}
44364462

4437-
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4463+
static void dequantize_mul_mat_vec_q3_K_cuda(
4464+
const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4465+
44384466
GGML_ASSERT(ncols % QK_K == 0);
44394467
const int ny = 2 / K_QUANTS_PER_ITERATION;
44404468
const int block_num_y = (nrows + ny - 1) / ny;
44414469
const dim3 block_nums(1, block_num_y, 1);
44424470
const dim3 block_dims(32, ny, 1);
44434471
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4472+
(void) nrows_y;
44444473
}
44454474

4446-
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4475+
static void dequantize_mul_mat_vec_q4_K_cuda(
4476+
const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4477+
44474478
GGML_ASSERT(ncols % QK_K == 0);
44484479
const int ny = 2 / K_QUANTS_PER_ITERATION;
44494480
const int block_num_y = (nrows + ny - 1) / ny;
44504481
const dim3 block_nums(1, block_num_y, 1);
44514482
const dim3 block_dims(32, ny, 1);
44524483
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4484+
(void) nrows_y;
44534485
}
44544486

4455-
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4487+
static void dequantize_mul_mat_vec_q5_K_cuda(
4488+
const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4489+
44564490
GGML_ASSERT(ncols % QK_K == 0);
44574491
const dim3 block_dims(32, 1, 1);
44584492
dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
4493+
(void) nrows_y;
44594494
}
44604495

4461-
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4496+
static void dequantize_mul_mat_vec_q6_K_cuda(
4497+
const void * vx, const float * y, float * dst, const int ncols, const int nrows, const int nrows_y, cudaStream_t stream) {
4498+
44624499
GGML_ASSERT(ncols % QK_K == 0);
44634500
const int ny = 2 / K_QUANTS_PER_ITERATION;
44644501
const int block_num_y = (nrows + ny - 1) / ny;
44654502
const dim3 block_nums(1, block_num_y, 1);
44664503
const dim3 block_dims(32, ny, 1);
44674504
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4505+
(void) nrows_y;
44684506
}
44694507

44704508
static void mul_mat_vec_q4_0_q8_1_cuda(
@@ -4592,13 +4630,15 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
45924630
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
45934631
}
45944632

4595-
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4633+
static void convert_mul_mat_vec_f16_cuda(
4634+
const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows_x, const int nrows_y, cudaStream_t stream) {
4635+
45964636
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4597-
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4637+
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
45984638
const dim3 block_nums(1, block_num_y, 1);
45994639
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
4600-
dequantize_mul_mat_vec<1, 1, convert_f16>
4601-
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
4640+
dequantize_mul_mat_vec<1, 1, convert_f16, false>
4641+
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows_x, nrows_y);
46024642
}
46034643

46044644
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
@@ -5804,37 +5844,37 @@ inline void ggml_cuda_op_mul_mat_vec(
58045844

58055845
switch (src0->type) {
58065846
case GGML_TYPE_Q4_0:
5807-
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5847+
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58085848
break;
58095849
case GGML_TYPE_Q4_1:
5810-
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5850+
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58115851
break;
58125852
case GGML_TYPE_Q5_0:
5813-
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5853+
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58145854
break;
58155855
case GGML_TYPE_Q5_1:
5816-
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5856+
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58175857
break;
58185858
case GGML_TYPE_Q8_0:
5819-
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5859+
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58205860
break;
58215861
case GGML_TYPE_Q2_K:
5822-
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5862+
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58235863
break;
58245864
case GGML_TYPE_Q3_K:
5825-
dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5865+
dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58265866
break;
58275867
case GGML_TYPE_Q4_K:
5828-
dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5868+
dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58295869
break;
58305870
case GGML_TYPE_Q5_K:
5831-
dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5871+
dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58325872
break;
58335873
case GGML_TYPE_Q6_K:
5834-
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
5874+
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58355875
break;
58365876
case GGML_TYPE_F16:
5837-
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
5877+
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, ne10, cudaStream_main);
58385878
break;
58395879
default:
58405880
GGML_ASSERT(false);
@@ -6550,8 +6590,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
65506590
}
65516591

65526592
// no quantized non-contiguous support for lower CC kernels implemented
6553-
// const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
6554-
const bool nc_okay = false;
6593+
const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
65556594

65566595
if (all_on_device && nc_okay && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
65576596
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);

0 commit comments

Comments
 (0)