From 4fac0d0420440480fdc2501ca80376822b9aa8bd Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Fri, 21 Mar 2025 21:51:13 +0800 Subject: [PATCH 1/4] ggml : add 128-bit RVV support --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 6 +- ggml/src/ggml-cpu/ggml-cpu-quants.c | 1389 ++++++++++++++++++--------- ggml/src/ggml-impl.h | 29 + 4 files changed, 945 insertions(+), 480 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 740f9f69cf2ed..433628c4c313a 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -123,6 +123,7 @@ endif() option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6aa078a93ea8e..3a21a7fdebd22 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -320,7 +320,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") message(STATUS "RISC-V detected") if (GGML_RVV) - list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + if (GGML_RV_ZFH) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d) + else() + list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") message(STATUS "s390x detected") diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 4e0ae057244c9..138b9bc5a7ae3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_0); + size_t vl = QK8_0; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl); - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); const float d = amax / ((1 << 7) - 1); @@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); } #elif defined(__POWER9_VECTOR__) @@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_1); + size_t vl = QK8_1; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl); - 
vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); const float d = amax / ((1 << 7) - 1); @@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); // compute sum for y[i].s vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl); // set y[i].s int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); @@ -1822,6 +1822,123 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in yc[i].d = 1.0f / iscale; } +#elif defined(__riscv_v_intrinsic) + + if (__riscv_vlenb() == 16) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * restrict yc = y; // Cast to proper type + int frm; + __asm__ __volatile( + "fsrmi %[frm], 0b100" // RMM + : [frm] "=&r" (frm) + ); + int t0, t1, t2, t3, t4; + float ft0, ft1, ft2; + + for (int i = 0; i < nb; i++) { + __asm__ __volatile__( + "mv %[t0], %[x_block]\n\t" + "addi %[t1], %[x_block], 128\n\t" + "addi %[t2], %[x_block], 768\n\t" + "vsetvli zero, %[vl32], e32, m8\n\t" + "vle32.v v16, (%[t0])\n\t" + "vle32.v v24, (%[t1])\n\t" + "vfmax.vv v0, v16, v24\n\t" + "vfmin.vv v8, v16, v24\n\t" + "1:\n\t" + "addi %[t0], %[t0], 256\n\t" + "addi %[t1], %[t1], 256\n\t" + "vle32.v v16, (%[t0])\n\t" // last: 192..224 + "vle32.v v24, (%[t1])\n\t" // last: 224..256 + "vfmax.vv v0, v0, v16\n\t" + "vfmax.vv v0, v0, v24\n\t" + "vfmin.vv v8, v8, v16\n\t" + "vfmin.vv v8, v8, v24\n\t" + "bne %[t0], %[t2], 1b\n\t" + "vfredmax.vs v1, v0, v0\n\t" + "vfredmin.vs v8, v8, v8\n\t" + "vsetivli zero, 1, e32, m1\n\t" + "vfneg.v v9, v8\n\t" + "vmfgt.vv v0, v9, v1\n\t" + "vmerge.vvm v2, v1, v8, v0\n\t" + "fmv.w.x %[ft0], zero\n\t" + "vfmv.f.s %[ft1], v2\n\t" // max + "feq.s %[t2], %[ft0], %[ft1]\n\t" + "bne %[t2], zero, 8f\n\t" + "vmv.v.x v0, zero\n\t" + "li %[t2], -127\n\t" + "fcvt.s.w %[ft0], %[t2]\n\t" + "fdiv.s %[ft2], %[ft0], %[ft1]\n\t" // iscale + "addi %[t0], %[x_block], 768\n\t" + "addi %[t1], %[x_block], 896\n\t" + "fdiv.s %[ft1], %[ft1], %[ft0]\n\t" // d + "addi %[t2], %[qs], 192\n\t" + "addi %[t3], %[qs], 224\n\t" + "addi %[t4], %[bsums], 24\n\t" + "fsw %[ft1], 0(%[d])\n\t" + "vsetvli zero, %[vl32], e32, m8\n\t" + "6:\n\t" + "vfmul.vf v16, v16, %[ft2]\n\t" + "vfmul.vf v24, v24, %[ft2]\n\t" + "vsetvli zero, %[vl32], e16, m4\n\t" + "vfncvt.x.f.w v16, v16\n\t" + "vfncvt.x.f.w v24, v24\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vnclip.wx v16, v16, zero\n\t" + "vnclip.wx v24, v24, zero\n\t" + "vse8.v v16, (%[t2])\n\t" + "vse8.v v24, (%[t3])\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "vwredsum.vs v5, v17, v0\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "vwredsum.vs v7, v25, v0\n\t" + "vsetivli zero, 4, e16, 
m1\n\t" + "vslideup.vi v4, v5, 1\n\t" + "vslideup.vi v6, v7, 1\n\t" + "vslideup.vi v4, v6, 2\n\t" + "vse16.v v4, (%[t4])\n\t" + "beq %[t2], %[qs], 9f\n\t" + "addi %[t0], %[t0], -256\n\t" + "addi %[t1], %[t1], -256\n\t" + "addi %[t2], %[t2], -64\n\t" + "addi %[t3], %[t3], -64\n\t" + "addi %[t4], %[t4], -8\n\t" + "vsetvli zero, %[vl32], e32, m8\n\t" + "vle32.v v16, (%[t0])\n\t" + "vle32.v v24, (%[t1])\n\t" + "j 6b\n\t" + "8:\n\t" + "addi %[t1], %[qs], 128\n\t" + "sw zero, 0(%[d])\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vmv.v.x v0, zero\n\t" + "vse8.v v0, (%[qs])\n\t" + "vse8.v v0, (%[t1])\n\t" + "9:" + : [t0] "=&r" (t0), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [ft0] "=&f" (ft0), [ft1] "=&f" (ft1), [ft2] "=&f" (ft2), [t4] "=&r" (t4) + : [vl32] "r" (32), [vl128] "r" (128) + , [x_block] "r" (x + i * QK_K) + , [d] "r" (&yc[i].d), [qs] "r" (yc[i].qs), [bsums] "r" (yc[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } + + __asm__ __volatile__( + "fsrm %[frm]" + : + : [frm] "r" (frm) + ); + } else { + quantize_row_q8_K_ref(x, y, k); + } + #else quantize_row_q8_K_ref(x, y, k); #endif @@ -2391,33 +2508,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -2783,29 +2898,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx 
= __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -3132,65 +3245,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These temporary registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = 
__riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + qh = __riscv_vmnand_mm_b4(qh, qh, vl); + vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); + + sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi; } #elif defined(__POWER9_VECTOR__) @@ -3503,60 +3584,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = 
__riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } @@ -3970,17 +4021,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(accum); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk); + size_t vl = qk; for (; ib < nb; ++ib) { // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl); + vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); + vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); @@ -5175,83 +5226,173 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = 
__riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); - - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); - - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + if (__riscv_vlenb() >= 32) { + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m2_t scales = __riscv_vle8_v_u8m2(sc, vl); + vuint8m2_t aux = __riscv_vand_vx_u8m2(scales, 0x0F, vl); + + vint16m2_t q8sums = __riscv_vle16_v_i16m2(y[i].bsums, vl); + + vuint8m1_t scales_2 = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t mins8 = __riscv_vsrl_vx_u8m1(scales_2, 0x4, vl); + vint16m2_t mins = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vzext_vf2_u16m2(mins8, vl)); + vint32m4_t prod = __riscv_vwmul_vv_i32m4(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m4_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m2_t v_b = __riscv_vle8_v_u8m2(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m2_t q2_x = __riscv_vle8_v_u8m2(q2, vl); + + vuint8m2_t q2_0 = __riscv_vand_vx_u8m2(q2_x, 0x03, vl); + vuint8m2_t q2_1 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x2, vl), 0x03 , vl); + vuint8m2_t q2_2 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x4, vl), 0x03 , vl); + vuint8m2_t q2_3 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m2_t sc0 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 0+is, vl), vl); + vuint8m2_t sc1 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 2+is, vl), vl); + vuint8m2_t sc2 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 4+is, vl), vl); + vuint8m2_t sc3 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 6+is, vl), vl); + + vint16m4_t p0 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_0, sc0, vl)); + 
vint16m4_t p1 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_1, sc1, vl)); + vint16m4_t p2 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_2, sc2, vl)); + vint16m4_t p3 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_3, sc3, vl)); + + // load Q8 + vint8m2_t q8_0 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8+32, vl); + vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8+64, vl); + vint8m2_t q8_3 = __riscv_vle8_v_i8m2(q8+96, vl); + + vint32m8_t s0 = __riscv_vwmul_vv_i32m8(p0, __riscv_vwcvt_x_x_v_i16m4(q8_0, vl), vl); + vint32m8_t s1 = __riscv_vwmacc_vv_i32m8(s0, p1, __riscv_vwcvt_x_x_v_i16m4(q8_1, vl), vl); + vint32m8_t s2 = __riscv_vwmacc_vv_i32m8(s1, p2, __riscv_vwcvt_x_x_v_i16m4(q8_2, vl), vl); + vint32m8_t s3 = __riscv_vwmacc_vv_i32m8(s2, p3, __riscv_vwcvt_x_x_v_i16m4(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m8_i32m1(s3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum0); + + q2+=32; q8+=128; is=8; + + } - q2+=32; q8+=128; is=8; + sumf += dall * isum; + } + } else if (__riscv_vlenb() == 16) { + uint8_t atmp[16]; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "vle8.v v1, (%[sc])\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vle16.v v2, (%[bsums])\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v0, (%[q2])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v15, (%[scale])\n\t" + "vzext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" 
(64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + sumf += dall * isum; } - - sumf += dall * isum; - } *s = sumf; @@ -6117,96 +6258,209 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t utmp[4]; float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - - int sum_t = 0; - - for (int j = 0; j < QK_K; j += 128) { - - vl = 32; - - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; - - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = 
__riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + if (__riscv_vlenb() >= 32) { + for (int i = 0; i < nb; ++i) { + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m2_t vqh = __riscv_vle8_v_u8m2(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m2_t q3_x = __riscv_vle8_v_u8m2(q3, vl); + + vint8m2_t q3_0 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q3_x, 0x03, vl)); + vint8m2_t q3_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x2, vl), 0x03 , vl)); + vint8m2_t q3_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x4, vl), 0x03 , vl)); + vint8m2_t q3_3 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m2_t qh_m0 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_0 = __riscv_vmseq_vx_u8m2_b4(qh_m0, 0, vl); + vint8m2_t q3_m0 = __riscv_vsub_vx_i8m2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_1 = __riscv_vmseq_vx_u8m2_b4(qh_m1, 0, vl); + vint8m2_t q3_m1 = __riscv_vsub_vx_i8m2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_2 = __riscv_vmseq_vx_u8m2_b4(qh_m2, 0, vl); + vint8m2_t q3_m2 = __riscv_vsub_vx_i8m2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m2_t qh_m3 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_3 = __riscv_vmseq_vx_u8m2_b4(qh_m3, 0, vl); + vint8m2_t q3_m3 = __riscv_vsub_vx_i8m2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + // load Q8 and take product with Q3 + vint16m4_t a0 = __riscv_vwmul_vv_i16m4(q3_m0, __riscv_vle8_v_i8m2(q8, vl), vl); + vint16m4_t a1 = __riscv_vwmul_vv_i16m4(q3_m1, __riscv_vle8_v_i8m2(q8+32, vl), vl); + vint16m4_t a2 = __riscv_vwmul_vv_i16m4(q3_m2, __riscv_vle8_v_i8m2(q8+64, vl), vl); + vint16m4_t a3 = __riscv_vwmul_vv_i16m4(q3_m3, __riscv_vle8_v_i8m2(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m4_t aux0_0 = __riscv_vwmul_vx_i32m4(__riscv_vget_v_i16m4_i16m2(a0, 0), (scale[0]), vl); + vint32m4_t aux0_1 = 
__riscv_vwmacc_vx_i32m4(aux0_0, (scale[1]), __riscv_vget_v_i16m4_i16m2(a0, 1), vl); + vint32m4_t aux1_0 = __riscv_vwmacc_vx_i32m4(aux0_1, (scale[2]), __riscv_vget_v_i16m4_i16m2(a1, 0), vl); + vint32m4_t aux1_1 = __riscv_vwmacc_vx_i32m4(aux1_0, (scale[3]), __riscv_vget_v_i16m4_i16m2(a1, 1), vl); + vint32m4_t aux2_0 = __riscv_vwmacc_vx_i32m4(aux1_1, (scale[4]), __riscv_vget_v_i16m4_i16m2(a2, 0), vl); + vint32m4_t aux2_1 = __riscv_vwmacc_vx_i32m4(aux2_0, (scale[5]), __riscv_vget_v_i16m4_i16m2(a2, 1), vl); + vint32m4_t aux3_0 = __riscv_vwmacc_vx_i32m4(aux2_1, (scale[6]), __riscv_vget_v_i16m4_i16m2(a3, 0), vl); + vint32m4_t aux3_1 = __riscv_vwmacc_vx_i32m4(aux3_0, (scale[7]), __riscv_vget_v_i16m4_i16m2(a3, 1), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(aux3_1, vzero, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum0); + + q3 += 32; q8 += 128; scale += 8; + + } - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d * sum_t; + + } + } else if (__riscv_vlenb() == 16) { + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); - q3 += 32; q8 += 128; scale += 8; + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, 
v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t"\ + "vle8.v v15, (%[scale])\n\t" + "vsext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; } - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d*sum_t; - } *s = sumf; @@ -6926,67 +7180,172 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; - for (int i = 0; i < nb; ++i) { - - size_t vl = 8; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - vl = 32; - - int32_t sum_1 = 0; - int32_t sum_2 = 0; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + if (__riscv_vlenb() >= 32) { + for (int i = 0; i < nb; 
++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); + vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); + vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); + vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, vl), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m2_t q4_x = __riscv_vle8_v_u8m2(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m2_t q8_0 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q4_0 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q4_x, 0x0F, vl)); + vint16m4_t qv_0 = __riscv_vwmul_vv_i16m4(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m4_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8+32, vl); + vint8m2_t q4_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q4_x, 0x04, vl)); + vint16m4_t qv_1 = __riscv_vwmul_vv_i16m4(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m4_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } - q4 += 32; q8 += 64; + sumf += d*(sum_1 + sum_2); + + } + } else if (__riscv_vlenb() == 16) { + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int tmp, tmp2, sumi; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "vsetivli zero, 4, e32, m1\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v8, (%[utmp]), %[t2]\n\t" + "vsse32.v v1, (%[t1]), %[t2]\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vredsum.vs v0, v6, v0\n\t" + "vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" 
(kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v0, (%[q4])\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vle8.v v2, (%[scale])\n\t" + "vmv.v.x v0, zero\n\t" + "vzext.vf4 v1, v2\n\t" + "vsetvli zero, %[vl32], e16, m4\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v6, v7, 1\n\t" + "vslideup.vi v4, v5, 1\n\t" + "vslideup.vi v4, v6, 2\n\t" + "vmul.vv v8, v4, v1\n\t" + "vredsum.vs v0, v8, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; } - - sumf += d*(sum_1 + sum_2); - } *s = sumf; @@ -7722,9 +8081,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); + vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); + vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); @@ -7733,11 +8092,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi utmp[2] = uaux; utmp[0] &= kmask1; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); + vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); vl = 32; @@ -7746,43 +8105,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const 
voi uint8_t m = 1; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl); for (int j = 0; j < QK_K/64; ++j) { // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl); + vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl); // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); + vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl)); + vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl); + vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl); m <<= 1; - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); + vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl)); + vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl); + vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl); m <<= 1; - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl); + vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl); - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl); + vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl); - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl); - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; } - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + sums += aux32 * d; } @@ -8668,84 +9026,157 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const int8_t * GGML_RESTRICT scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t 
q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q6 += 64; qh += 32; q8 += 128; is=8; + if (__riscv_vlenb() >= 32) { + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m2_t qh_x = __riscv_vle8_v_u8m2(qh, vl); + + // load Q6 + 
vuint8m2_t q6_0 = __riscv_vle8_v_u8m2(q6, vl); + vuint8m2_t q6_1 = __riscv_vle8_v_u8m2(q6+32, vl); + + vuint8m2_t q6a_0 = __riscv_vand_vx_u8m2(q6_0, 0x0F, vl); + vuint8m2_t q6a_1 = __riscv_vand_vx_u8m2(q6_1, 0x0F, vl); + vuint8m2_t q6s_0 = __riscv_vsrl_vx_u8m2(q6_0, 0x04, vl); + vuint8m2_t q6s_1 = __riscv_vsrl_vx_u8m2(q6_1, 0x04, vl); + + vuint8m2_t qh_0 = __riscv_vand_vx_u8m2(qh_x, 0x03, vl); + vuint8m2_t qh_1 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x2, vl), 0x03 , vl); + vuint8m2_t qh_2 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x4, vl), 0x03 , vl); + vuint8m2_t qh_3 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m2_t qhi_0 = __riscv_vor_vv_u8m2(q6a_0, __riscv_vsll_vx_u8m2(qh_0, 0x04, vl), vl); + vuint8m2_t qhi_1 = __riscv_vor_vv_u8m2(q6a_1, __riscv_vsll_vx_u8m2(qh_1, 0x04, vl), vl); + vuint8m2_t qhi_2 = __riscv_vor_vv_u8m2(q6s_0, __riscv_vsll_vx_u8m2(qh_2, 0x04, vl), vl); + vuint8m2_t qhi_3 = __riscv_vor_vv_u8m2(q6s_1, __riscv_vsll_vx_u8m2(qh_3, 0x04, vl), vl); + + vint8m2_t a_0 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_0), 32, vl); + vint8m2_t a_1 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_1), 32, vl); + vint8m2_t a_2 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_2), 32, vl); + vint8m2_t a_3 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m4_t va_q_0 = __riscv_vwmul_vv_i16m4(a_0, __riscv_vle8_v_i8m2(q8, vl), vl); + vint16m4_t va_q_1 = __riscv_vwmul_vv_i16m4(a_1, __riscv_vle8_v_i8m2(q8+32, vl), vl); + vint16m4_t va_q_2 = __riscv_vwmul_vv_i16m4(a_2, __riscv_vle8_v_i8m2(q8+64, vl), vl); + vint16m4_t va_q_3 = __riscv_vwmul_vv_i16m4(a_3, __riscv_vle8_v_i8m2(q8+96, vl), vl); + + vl = 16; + + vint32m4_t vaux_0 = __riscv_vwmul_vx_i32m4(__riscv_vget_v_i16m4_i16m2(va_q_0, 0), scale[is+0], vl); + vint32m4_t vaux_1 = __riscv_vwmacc_vx_i32m4(vaux_0, scale[is+1], __riscv_vget_v_i16m4_i16m2(va_q_0, 1), vl); + vint32m4_t vaux_2 = __riscv_vwmacc_vx_i32m4(vaux_1, scale[is+2], __riscv_vget_v_i16m4_i16m2(va_q_1, 0), vl); + vint32m4_t vaux_3 = __riscv_vwmacc_vx_i32m4(vaux_2, scale[is+3], __riscv_vget_v_i16m4_i16m2(va_q_1, 1), vl); + vint32m4_t vaux_4 = __riscv_vwmacc_vx_i32m4(vaux_3, scale[is+4], __riscv_vget_v_i16m4_i16m2(va_q_2, 0), vl); + vint32m4_t vaux_5 = __riscv_vwmacc_vx_i32m4(vaux_4, scale[is+5], __riscv_vget_v_i16m4_i16m2(va_q_2, 1), vl); + vint32m4_t vaux_6 = __riscv_vwmacc_vx_i32m4(vaux_5, scale[is+6], __riscv_vget_v_i16m4_i16m2(va_q_3, 0), vl); + vint32m4_t vaux_7 = __riscv_vwmacc_vx_i32m4(vaux_6, scale[is+7], __riscv_vget_v_i16m4_i16m2(va_q_3, 1), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(vaux_7, vzero, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum0); + + q6 += 64; qh += 32; q8 += 128; is += 8; + + } + + sumf += d * sum_t; + + } + } else if (__riscv_vlenb() == 16) { + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v8, (%[q6])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vsetvli zero, %[vl128], e8, 
m8\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v2, (%[scale])\n\t" + "vsext.vf4 v4, v2\n\t" + "vmul.vv v2, v4, v10\n\t" + "vredsum.vs v0, v2, v0\n\t" + "vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + + sumf += d * sum_t; + } - - sumf += d * sum_t; - } *s = sumf; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 1fbcbd0456e99..be2e3fc915551 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -381,6 +381,35 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); return r; } +#elif defined(__riscv) && defined(GGML_RV_ZFH) + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + float f; + __asm__( + "fmv.h.x %[f], %[h]\n\t" + "fcvt.s.h %[f], %[f]" + : [f] "=&f" (f) + : [h] "r" (h) + ); + return f; + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + __asm__( + "fcvt.h.s %[f], %[f]\n\t" + "fmv.x.h %[h], %[f]" + : [h] "=&r" (res) + : [f] "f" (f) + ); + return res; + } + + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) + #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + #else // FP16 <-> FP32 From de8387e8fc9ac40a9bb46f45c38cd2276081e213 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Fri, 21 Mar 2025 22:05:26 +0800 Subject: [PATCH 2/4] ggml : revert to old RVV 256+ q2_K, q3_K, q4_K, q6_K impl --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 426 +++++++++++----------------- 1 file changed, 158 insertions(+), 268 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 138b9bc5a7ae3..99a1e940f0ff6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1822,123 +1822,6 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in yc[i].d = 1.0f / iscale; } -#elif defined(__riscv_v_intrinsic) - - if (__riscv_vlenb() == 16) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - block_q8_K * restrict yc = y; // Cast to proper type - int frm; - __asm__ __volatile( - "fsrmi %[frm], 0b100" // RMM - : [frm] "=&r" (frm) - ); - int t0, t1, t2, t3, t4; - float ft0, ft1, ft2; - - for (int i = 0; i < nb; i++) { - __asm__ __volatile__( 
- "mv %[t0], %[x_block]\n\t" - "addi %[t1], %[x_block], 128\n\t" - "addi %[t2], %[x_block], 768\n\t" - "vsetvli zero, %[vl32], e32, m8\n\t" - "vle32.v v16, (%[t0])\n\t" - "vle32.v v24, (%[t1])\n\t" - "vfmax.vv v0, v16, v24\n\t" - "vfmin.vv v8, v16, v24\n\t" - "1:\n\t" - "addi %[t0], %[t0], 256\n\t" - "addi %[t1], %[t1], 256\n\t" - "vle32.v v16, (%[t0])\n\t" // last: 192..224 - "vle32.v v24, (%[t1])\n\t" // last: 224..256 - "vfmax.vv v0, v0, v16\n\t" - "vfmax.vv v0, v0, v24\n\t" - "vfmin.vv v8, v8, v16\n\t" - "vfmin.vv v8, v8, v24\n\t" - "bne %[t0], %[t2], 1b\n\t" - "vfredmax.vs v1, v0, v0\n\t" - "vfredmin.vs v8, v8, v8\n\t" - "vsetivli zero, 1, e32, m1\n\t" - "vfneg.v v9, v8\n\t" - "vmfgt.vv v0, v9, v1\n\t" - "vmerge.vvm v2, v1, v8, v0\n\t" - "fmv.w.x %[ft0], zero\n\t" - "vfmv.f.s %[ft1], v2\n\t" // max - "feq.s %[t2], %[ft0], %[ft1]\n\t" - "bne %[t2], zero, 8f\n\t" - "vmv.v.x v0, zero\n\t" - "li %[t2], -127\n\t" - "fcvt.s.w %[ft0], %[t2]\n\t" - "fdiv.s %[ft2], %[ft0], %[ft1]\n\t" // iscale - "addi %[t0], %[x_block], 768\n\t" - "addi %[t1], %[x_block], 896\n\t" - "fdiv.s %[ft1], %[ft1], %[ft0]\n\t" // d - "addi %[t2], %[qs], 192\n\t" - "addi %[t3], %[qs], 224\n\t" - "addi %[t4], %[bsums], 24\n\t" - "fsw %[ft1], 0(%[d])\n\t" - "vsetvli zero, %[vl32], e32, m8\n\t" - "6:\n\t" - "vfmul.vf v16, v16, %[ft2]\n\t" - "vfmul.vf v24, v24, %[ft2]\n\t" - "vsetvli zero, %[vl32], e16, m4\n\t" - "vfncvt.x.f.w v16, v16\n\t" - "vfncvt.x.f.w v24, v24\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vnclip.wx v16, v16, zero\n\t" - "vnclip.wx v24, v24, zero\n\t" - "vse8.v v16, (%[t2])\n\t" - "vse8.v v24, (%[t3])\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "vwredsum.vs v5, v17, v0\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "vwredsum.vs v7, v25, v0\n\t" - "vsetivli zero, 4, e16, m1\n\t" - "vslideup.vi v4, v5, 1\n\t" - "vslideup.vi v6, v7, 1\n\t" - "vslideup.vi v4, v6, 2\n\t" - "vse16.v v4, (%[t4])\n\t" - "beq %[t2], %[qs], 9f\n\t" - "addi %[t0], %[t0], -256\n\t" - "addi %[t1], %[t1], -256\n\t" - "addi %[t2], %[t2], -64\n\t" - "addi %[t3], %[t3], -64\n\t" - "addi %[t4], %[t4], -8\n\t" - "vsetvli zero, %[vl32], e32, m8\n\t" - "vle32.v v16, (%[t0])\n\t" - "vle32.v v24, (%[t1])\n\t" - "j 6b\n\t" - "8:\n\t" - "addi %[t1], %[qs], 128\n\t" - "sw zero, 0(%[d])\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vmv.v.x v0, zero\n\t" - "vse8.v v0, (%[qs])\n\t" - "vse8.v v0, (%[t1])\n\t" - "9:" - : [t0] "=&r" (t0), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [ft0] "=&f" (ft0), [ft1] "=&f" (ft1), [ft2] "=&f" (ft2), [t4] "=&r" (t4) - : [vl32] "r" (32), [vl128] "r" (128) - , [x_block] "r" (x + i * QK_K) - , [d] "r" (&yc[i].d), [qs] "r" (yc[i].qs), [bsums] "r" (yc[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - } - - __asm__ __volatile__( - "fsrm %[frm]" - : - : [frm] "r" (frm) - ); - } else { - quantize_row_q8_K_ref(x, y, k); - } - #else quantize_row_q8_K_ref(x, y, k); #endif @@ -5228,78 +5111,79 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; if (__riscv_vlenb() >= 32) { - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; for (int i = 0; i < nb; 
++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; + const int8_t * q8 = y[i].qs; const uint8_t * sc = x[i].scales; - + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - + size_t vl = 16; - - vuint8m2_t scales = __riscv_vle8_v_u8m2(sc, vl); - vuint8m2_t aux = __riscv_vand_vx_u8m2(scales, 0x0F, vl); - - vint16m2_t q8sums = __riscv_vle16_v_i16m2(y[i].bsums, vl); - - vuint8m1_t scales_2 = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t mins8 = __riscv_vsrl_vx_u8m1(scales_2, 0x4, vl); - vint16m2_t mins = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vzext_vf2_u16m2(mins8, vl)); - vint32m4_t prod = __riscv_vwmul_vv_i32m4(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m4_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + vl = 32; - + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m2_t v_b = __riscv_vle8_v_u8m2(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is = 0; + int isum = 0; + + for (int j = 0; j < QK_K / 128; ++j) { // load Q2 - vuint8m2_t q2_x = __riscv_vle8_v_u8m2(q2, vl); - - vuint8m2_t q2_0 = __riscv_vand_vx_u8m2(q2_x, 0x03, vl); - vuint8m2_t q2_1 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x2, vl), 0x03 , vl); - vuint8m2_t q2_2 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x4, vl), 0x03 , vl); - vuint8m2_t q2_3 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q2_x, 0x6, vl), 0x03 , vl); - + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + // duplicate scale elements for product - vuint8m2_t sc0 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 0+is, vl), vl); - vuint8m2_t sc1 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 2+is, vl), vl); - vuint8m2_t sc2 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 4+is, vl), vl); - vuint8m2_t sc3 = __riscv_vrgather_vv_u8m2(aux, __riscv_vadd_vx_u8m2(v_b, 6+is, vl), vl); - - vint16m4_t p0 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_0, sc0, vl)); - vint16m4_t p1 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_1, sc1, vl)); - vint16m4_t p2 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_2, sc2, vl)); - vint16m4_t p3 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwmulu_vv_u16m4(q2_3, sc3, vl)); - + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = 
__riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + // load Q8 - vint8m2_t q8_0 = __riscv_vle8_v_i8m2(q8, vl); - vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8+32, vl); - vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8+64, vl); - vint8m2_t q8_3 = __riscv_vle8_v_i8m2(q8+96, vl); - - vint32m8_t s0 = __riscv_vwmul_vv_i32m8(p0, __riscv_vwcvt_x_x_v_i16m4(q8_0, vl), vl); - vint32m8_t s1 = __riscv_vwmacc_vv_i32m8(s0, p1, __riscv_vwcvt_x_x_v_i16m4(q8_1, vl), vl); - vint32m8_t s2 = __riscv_vwmacc_vv_i32m8(s1, p2, __riscv_vwcvt_x_x_v_i16m4(q8_2, vl), vl); - vint32m8_t s3 = __riscv_vwmacc_vv_i32m8(s2, p3, __riscv_vwcvt_x_x_v_i16m4(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m8_i32m1(s3, vzero, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum0); - - q2+=32; q8+=128; is=8; - + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } sumf += dall * isum; @@ -6279,7 +6163,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint8_t m = 1; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m2_t vqh = __riscv_vle8_v_u8m2(qh, vl); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); int sum_t = 0; @@ -6288,63 +6172,66 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi vl = 32; // load Q3 - vuint8m2_t q3_x = __riscv_vle8_v_u8m2(q3, vl); + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - vint8m2_t q3_0 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q3_x, 0x03, vl)); - vint8m2_t q3_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x2, vl), 0x03 , vl)); - vint8m2_t q3_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x4, vl), 0x03 , vl)); - vint8m2_t q3_3 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(q3_x, 0x6, vl), 0x03 , vl)); + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); // 
compute mask for subtraction - vuint8m2_t qh_m0 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_0 = __riscv_vmseq_vx_u8m2_b4(qh_m0, 0, vl); - vint8m2_t q3_m0 = __riscv_vsub_vx_i8m2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); m <<= 1; - vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_1 = __riscv_vmseq_vx_u8m2_b4(qh_m1, 0, vl); - vint8m2_t q3_m1 = __riscv_vsub_vx_i8m2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); m <<= 1; - vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_2 = __riscv_vmseq_vx_u8m2_b4(qh_m2, 0, vl); - vint8m2_t q3_m2 = __riscv_vsub_vx_i8m2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); m <<= 1; - vuint8m2_t qh_m3 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_3 = __riscv_vmseq_vx_u8m2_b4(qh_m3, 0, vl); - vint8m2_t q3_m3 = __riscv_vsub_vx_i8m2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); m <<= 1; - + // load Q8 and take product with Q3 - vint16m4_t a0 = __riscv_vwmul_vv_i16m4(q3_m0, __riscv_vle8_v_i8m2(q8, vl), vl); - vint16m4_t a1 = __riscv_vwmul_vv_i16m4(q3_m1, __riscv_vle8_v_i8m2(q8+32, vl), vl); - vint16m4_t a2 = __riscv_vwmul_vv_i16m4(q3_m2, __riscv_vle8_v_i8m2(q8+64, vl), vl); - vint16m4_t a3 = __riscv_vwmul_vv_i16m4(q3_m3, __riscv_vle8_v_i8m2(q8+96, vl), vl); - + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vl = 16; - + // retrieve lane to multiply with scale - vint32m4_t aux0_0 = __riscv_vwmul_vx_i32m4(__riscv_vget_v_i16m4_i16m2(a0, 0), (scale[0]), vl); - vint32m4_t aux0_1 = __riscv_vwmacc_vx_i32m4(aux0_0, (scale[1]), __riscv_vget_v_i16m4_i16m2(a0, 1), vl); - vint32m4_t aux1_0 = __riscv_vwmacc_vx_i32m4(aux0_1, (scale[2]), __riscv_vget_v_i16m4_i16m2(a1, 0), vl); - vint32m4_t aux1_1 = __riscv_vwmacc_vx_i32m4(aux1_0, (scale[3]), __riscv_vget_v_i16m4_i16m2(a1, 1), vl); - vint32m4_t aux2_0 = __riscv_vwmacc_vx_i32m4(aux1_1, (scale[4]), __riscv_vget_v_i16m4_i16m2(a2, 0), vl); - vint32m4_t aux2_1 = __riscv_vwmacc_vx_i32m4(aux2_0, (scale[5]), __riscv_vget_v_i16m4_i16m2(a2, 1), vl); - vint32m4_t aux3_0 = __riscv_vwmacc_vx_i32m4(aux2_1, (scale[6]), __riscv_vget_v_i16m4_i16m2(a3, 0), vl); - vint32m4_t aux3_1 = __riscv_vwmacc_vx_i32m4(aux3_0, (scale[7]), __riscv_vget_v_i16m4_i16m2(a3, 1), vl); + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = 
__riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(aux3_1, vzero, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum0); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); q3 += 32; q8 += 128; scale += 8; } - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d * sum_t; + + sumf += d*sum_t; } } else if (__riscv_vlenb() == 16) { @@ -7182,15 +7069,15 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi if (__riscv_vlenb() >= 32) { for (int i = 0; i < nb; ++i) { - + size_t vl = 8; const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); - vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); - vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); @@ -7199,11 +7086,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi utmp[2] = uaux; utmp[0] &= kmask1; - vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); - vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, vl), vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); const uint8_t * GGML_RESTRICT q4 = x[i].qs; @@ -7218,28 +7105,28 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi for (int j = 0; j < QK_K/64; ++j) { // load Q4 - vuint8m2_t q4_x = __riscv_vle8_v_u8m2(q4, vl); + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); // load Q8 and multiply it with lower Q4 nibble - vint8m2_t q8_0 = __riscv_vle8_v_i8m2(q8, vl); - vint8m2_t q4_0 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q4_x, 0x0F, vl)); - vint16m4_t qv_0 = __riscv_vwmul_vv_i16m4(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m4_i16m1(qv_0, vzero, vl); + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = 
__riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; // load Q8 and multiply it with upper Q4 nibble - vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8+32, vl); - vint8m2_t q4_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q4_x, 0x04, vl)); - vint16m4_t qv_1 = __riscv_vwmul_vv_i16m4(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m4_i16m1(qv_1, vzero, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; q4 += 32; q8 += 64; } - + sumf += d*(sum_1 + sum_2); } @@ -9029,7 +8916,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi if (__riscv_vlenb() >= 32) { for (int i = 0; i < nb; ++i) { - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const uint8_t * GGML_RESTRICT q6 = x[i].ql; @@ -9050,54 +8937,57 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi vl = 32; // load qh - vuint8m2_t qh_x = __riscv_vle8_v_u8m2(qh, vl); + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); // load Q6 - vuint8m2_t q6_0 = __riscv_vle8_v_u8m2(q6, vl); - vuint8m2_t q6_1 = __riscv_vle8_v_u8m2(q6+32, vl); + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - vuint8m2_t q6a_0 = __riscv_vand_vx_u8m2(q6_0, 0x0F, vl); - vuint8m2_t q6a_1 = __riscv_vand_vx_u8m2(q6_1, 0x0F, vl); - vuint8m2_t q6s_0 = __riscv_vsrl_vx_u8m2(q6_0, 0x04, vl); - vuint8m2_t q6s_1 = __riscv_vsrl_vx_u8m2(q6_1, 0x04, vl); + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - vuint8m2_t qh_0 = __riscv_vand_vx_u8m2(qh_x, 0x03, vl); - vuint8m2_t qh_1 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x2, vl), 0x03 , vl); - vuint8m2_t qh_2 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x4, vl), 0x03 , vl); - vuint8m2_t qh_3 = __riscv_vand_vx_u8m2(__riscv_vsrl_vx_u8m2(qh_x, 0x6, vl), 0x03 , vl); + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - vuint8m2_t qhi_0 = __riscv_vor_vv_u8m2(q6a_0, __riscv_vsll_vx_u8m2(qh_0, 0x04, vl), vl); - vuint8m2_t qhi_1 = __riscv_vor_vv_u8m2(q6a_1, __riscv_vsll_vx_u8m2(qh_1, 0x04, vl), vl); - vuint8m2_t qhi_2 = __riscv_vor_vv_u8m2(q6s_0, __riscv_vsll_vx_u8m2(qh_2, 0x04, vl), vl); - vuint8m2_t qhi_3 = __riscv_vor_vv_u8m2(q6s_1, __riscv_vsll_vx_u8m2(qh_3, 0x04, vl), vl); + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - vint8m2_t a_0 = 
__riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_0), 32, vl); - vint8m2_t a_1 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_1), 32, vl); - vint8m2_t a_2 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_2), 32, vl); - vint8m2_t a_3 = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(qhi_3), 32, vl); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); // load Q8 and take product - vint16m4_t va_q_0 = __riscv_vwmul_vv_i16m4(a_0, __riscv_vle8_v_i8m2(q8, vl), vl); - vint16m4_t va_q_1 = __riscv_vwmul_vv_i16m4(a_1, __riscv_vle8_v_i8m2(q8+32, vl), vl); - vint16m4_t va_q_2 = __riscv_vwmul_vv_i16m4(a_2, __riscv_vle8_v_i8m2(q8+64, vl), vl); - vint16m4_t va_q_3 = __riscv_vwmul_vv_i16m4(a_3, __riscv_vle8_v_i8m2(q8+96, vl), vl); + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); vl = 16; - vint32m4_t vaux_0 = __riscv_vwmul_vx_i32m4(__riscv_vget_v_i16m4_i16m2(va_q_0, 0), scale[is+0], vl); - vint32m4_t vaux_1 = __riscv_vwmacc_vx_i32m4(vaux_0, scale[is+1], __riscv_vget_v_i16m4_i16m2(va_q_0, 1), vl); - vint32m4_t vaux_2 = __riscv_vwmacc_vx_i32m4(vaux_1, scale[is+2], __riscv_vget_v_i16m4_i16m2(va_q_1, 0), vl); - vint32m4_t vaux_3 = __riscv_vwmacc_vx_i32m4(vaux_2, scale[is+3], __riscv_vget_v_i16m4_i16m2(va_q_1, 1), vl); - vint32m4_t vaux_4 = __riscv_vwmacc_vx_i32m4(vaux_3, scale[is+4], __riscv_vget_v_i16m4_i16m2(va_q_2, 0), vl); - vint32m4_t vaux_5 = __riscv_vwmacc_vx_i32m4(vaux_4, scale[is+5], __riscv_vget_v_i16m4_i16m2(va_q_2, 1), vl); - vint32m4_t vaux_6 = __riscv_vwmacc_vx_i32m4(vaux_5, scale[is+6], __riscv_vget_v_i16m4_i16m2(va_q_3, 0), vl); - vint32m4_t vaux_7 = __riscv_vwmacc_vx_i32m4(vaux_6, scale[is+7], __riscv_vget_v_i16m4_i16m2(va_q_3, 1), vl); + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(vaux_7, vzero, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum0); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - q6 += 64; qh += 32; q8 += 128; is += 8; + q6 += 64; qh += 32; q8 += 128; is=8; } From 0b43956499315af933a5cd85d8dd1719b9dab5f5 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Sun, 23 Mar 2025 23:46:13 +0800 Subject: [PATCH 3/4] remove trailing whitespaces --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 152 ++++++++++++++-------------- 1 file changed, 76 insertions(+), 76 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 99a1e940f0ff6..3f535d513d286 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -6144,70 +6144,70 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; if (__riscv_vlenb() >= 32) { for (int i = 0; i < nb; ++i) { - + const uint8_t * GGML_RESTRICT q3 = x[i].qs; const uint8_t * GGML_RESTRICT qh = x[i].hmask; const int8_t * GGML_RESTRICT q8 = y[i].qs; - + memcpy(aux, x[i].scales, 12); utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - + int8_t * scale = (int8_t *)utmp; for (int j = 0; j < 16; ++j) scale[j] -= 32; - - + + size_t vl = 32; uint8_t m = 1; - + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - + int sum_t = 0; - + for (int j = 0; j < QK_K; j += 128) { - + vl = 32; - + // load Q3 vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - + // compute mask for subtraction vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); m <<= 1; - + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); m <<= 1; - + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); m <<= 1; - + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); m <<= 1; - + // load Q8 and take product with Q3 vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - + vl = 16; - + // retrieve lane to multiply with scale vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); 
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); @@ -6217,22 +6217,22 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - + q3 += 32; q8 += 128; scale += 8; - + } - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - + sumf += d*sum_t; - + } } else if (__riscv_vlenb() == 16) { for (int i = 0; i < nb; ++i) { @@ -7071,64 +7071,64 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi for (int i = 0; i < nb; ++i) { size_t vl = 8; - + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - + memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; - + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - + const uint8_t * GGML_RESTRICT q4 = x[i].qs; const int8_t * GGML_RESTRICT q8 = y[i].qs; - + vl = 32; - + int32_t sum_1 = 0; int32_t sum_2 = 0; - + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - + for (int j = 0; j < QK_K/64; ++j) { // load Q4 vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - + // load Q8 and multiply it with lower Q4 nibble vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - + // load Q8 and multiply it with upper Q4 nibble vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - + q4 += 32; q8 += 64; - + } - + sumf += d*(sum_1 + sum_2); - + } } else if (__riscv_vlenb() == 16) { for (int i = 0; i < nb; ++i) { @@ -7180,13 +7180,13 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi , "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31" ); sumf -= dmin * sumi; - + const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - + sumi = 0; const uint8_t * scale = scales; - + for (int j = 0; j < QK_K/128; ++j) { int vl128 = 128, vl64 = 64, vl32 = 32; __asm__ __volatile__( @@ -7230,7 +7230,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi q4 += 64; q8 += 128; scale += 4; } - + sumf += d * sumi; } } @@ -8918,59 +8918,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - + const uint8_t * GGML_RESTRICT q6 = x[i].ql; const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; - + const int8_t * GGML_RESTRICT scale = x[i].scales; - + size_t vl; - + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - + int sum_t = 0; int is = 0; - + for (int j = 0; j < QK_K/128; ++j) { - + vl = 32; - + // load qh vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - + // load Q6 vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - + // load Q8 and take product vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - + vl = 16; - + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); @@ -8979,35 +8979,35 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - + vint32m1_t isum0 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - + q6 += 64; qh += 32; q8 += 128; is=8; - + } - + sumf += d * sum_t; - + } } else if (__riscv_vlenb() == 16) { for (int i = 0; i < nb; ++i) { - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - + const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - + const int8_t * restrict scale = x[i].scales; - + int sum_t = 0; int t0; - + for (int j = 0; j < QK_K/128; ++j) { __asm__ __volatile__( "vsetvli zero, %[vl32], e8, m2\n\t" @@ -9063,9 +9063,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ); q6 += 64; qh += 32; q8 += 128; scale += 8; } - + sumf += d * sum_t; - + } } From d1cac3da9ca6be1a02885830db0278d213f93b6d Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Tue, 25 Mar 2025 18:17:59 +0800 Subject: [PATCH 4/4] restructure vector length selection code --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 51 ++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 3f535d513d286..91a81bdc3ccd0 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -5108,12 +5108,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - if (__riscv_vlenb() >= 32) { - uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + uint8_t atmp[16]; + switch (vector_length) { + case 256: for (int i = 0; i < nb; ++i) { const uint8_t * q2 = x[i].qs; const int8_t * q8 = y[i].qs; @@ -5188,8 +5191,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += dall * isum; } - } else if (__riscv_vlenb() == 16) { - uint8_t atmp[16]; + break; + case 128: for (int i = 0; i < nb; ++i) { const uint8_t * q2 = x[i].qs; const int8_t * q8 = y[i].qs; @@ -5277,6 +5280,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += dall * isum; } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -6141,8 +6148,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t aux[3]; uint32_t utmp[4]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - if (__riscv_vlenb() >= 32) { + + switch (vector_length) { + case 256: for (int i = 0; i < nb; ++i) { const uint8_t * GGML_RESTRICT q3 = x[i].qs; @@ -6234,7 +6244,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += d*sum_t; } - } else if (__riscv_vlenb() == 16) { + break; + case 128: for (int i = 0; i < nb; ++i) { const uint8_t * restrict q3 = x[i].qs; const uint8_t * restrict qh = x[i].hmask; @@ -6348,6 +6359,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi 
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; sumf += d * isum; } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -7065,9 +7080,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - if (__riscv_vlenb() >= 32) { + switch (vector_length) { + case 256: for (int i = 0; i < nb; ++i) { size_t vl = 8; @@ -7130,7 +7147,8 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += d*(sum_1 + sum_2); } - } else if (__riscv_vlenb() == 16) { + break; + case 128: for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); @@ -7233,6 +7251,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += d * sumi; } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -8912,9 +8934,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - if (__riscv_vlenb() >= 32) { + switch (vector_length) { + case 256: for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -8994,7 +9018,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += d * sum_t; } - } else if (__riscv_vlenb() == 16) { + break; + case 128: for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -9067,6 +9092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf += d * sum_t; } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf;
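
Reviewer note appended after the series (not part of any patch): the reason these kernels need a separate VLEN = 128 path at all is register-group capacity. A single e8,m1 register holds VLEN/8 bytes, so the 32-byte sub-blocks used throughout the K-quant loops fit in m1 only when VLEN >= 256; at VLEN = 128 they need LMUL = 2 or the hand-written assembly above. A minimal probe under that assumption (the main() and printout are illustrative; the intrinsics are the same ones the patches call):

#include <stdio.h>
#include <riscv_vector.h>

int main(void) {
    // __riscv_vlenb() returns VLEN in bytes; VLEN in bits is vlenb * 8,
    // the expression PATCH 4/4 switches on (case 128 / case 256).
    const unsigned vlen_bits = (unsigned) __riscv_vlenb() * 8;

    // Request 32 byte-sized elements at LMUL = 1 and at LMUL = 2.
    size_t vl_m1 = __riscv_vsetvl_e8m1(32);  // 16 when VLEN == 128, 32 when VLEN >= 256
    size_t vl_m2 = __riscv_vsetvl_e8m2(32);  // already 32 when VLEN == 128

    printf("VLEN = %u bits: e8,m1 -> vl = %zu, e8,m2 -> vl = %zu\n", vlen_bits, vl_m1, vl_m2);
    return 0;
}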
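
Reviewer note (not part of the patch): the GGML_RV_ZFH branch added to ggml-impl.h relies only on the Zfhmin subset, since fmv.h.x / fmv.x.h and fcvt.s.h / fcvt.h.s are moves and conversions rather than half-precision arithmetic, which is why -march=rv64gcv_zfhmin suffices. Below is a standalone restatement assuming a toolchain invoked with that same -march string; fp16_bits, fp16_to_fp32 and fp32_to_fp16 are local names for this sketch, not ggml identifiers:

#include <stdint.h>
#include <stdio.h>

typedef uint16_t fp16_bits;  /* same storage as ggml_fp16_t */

static inline float fp16_to_fp32(fp16_bits h) {
    float f;
    __asm__("fmv.h.x  %[f], %[h]\n\t"  /* move the raw 16-bit pattern into an FP register */
            "fcvt.s.h %[f], %[f]"      /* widen half to single */
            : [f] "=&f" (f)
            : [h] "r" (h));
    return f;
}

static inline fp16_bits fp32_to_fp16(float f) {
    fp16_bits h;
    __asm__("fcvt.h.s %[t], %[t]\n\t"  /* narrow single to half, current rounding mode */
            "fmv.x.h  %[h], %[t]"      /* move the raw half pattern back out */
            : [h] "=r" (h), [t] "+f" (f));
    return h;
}

int main(void) {
    const fp16_bits h = fp32_to_fp16(0.333f);
    printf("0x%04x -> %f\n", (unsigned) h, (double) fp16_to_fp32(h));
    return 0;
}

Keeping each direction at two instructions matters because GGML_FP16_TO_FP32 sits on the hot path of every *_K dot product in this file.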
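
Reviewer note (not part of the patch): in both q4_K branches the pair of strided vlse16 loads followed by vwmul / vredsum is the super-block min correction. Stated in scalar form, assuming the usual q4_K layout (16 bsums entries, each the sum of 16 q8 values, and 8 unpacked 6-bit mins); the helper name is illustrative:

#include <stdint.h>

/* What the strided-load sequence computes before it is subtracted from sumf. */
static float q4k_min_correction(const int16_t bsums[16], const uint8_t mins[8],
                                float dmin /* y->d * fp16_to_fp32(x->dmin) */) {
    int32_t sumi = 0;
    for (int j = 0; j < 8; ++j) {
        /* pair the per-16-element bsums into per-32-element sums, one per min */
        sumi += (bsums[2*j + 0] + bsums[2*j + 1]) * mins[j];
    }
    return dmin * (float) sumi;  /* caller does: sumf -= q4k_min_correction(...) */
}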
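
Reviewer note (not part of the patch): the three q6_K variants above (the m1 intrinsics for VLEN >= 256, the m2 intrinsics being replaced, and the VLEN = 128 assembly) all compute the same per-128-element partial sum, so a scalar restatement may help when checking the bit manipulation: the low 4 bits of each quant come from ql, the top 2 bits from qh, and the value is re-biased by -32 before the dot product with q8 under per-16-element scales. The function name below is illustrative; the layout follows the existing reference kernel:

#include <stdint.h>

/* One 128-value chunk of the q6_K dot product; the caller accumulates the result
 * into sum_t and finally multiplies by d = fp16_to_fp32(x->d) * y->d. */
static int q6k_chunk_dot(const uint8_t * ql, const uint8_t * qh,
                         const int8_t * q8, const int8_t * scale) {
    int sum = 0;
    for (int l = 0; l < 32; ++l) {
        /* rebuild four 6-bit quants: low nibble from ql, two high bits from qh, bias -32 */
        const int8_t a0 = (int8_t)((ql[l +  0] & 0x0F) | (((qh[l] >> 0) & 3) << 4)) - 32;
        const int8_t a1 = (int8_t)((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) - 32;
        const int8_t a2 = (int8_t)((ql[l +  0] >>   4) | (((qh[l] >> 4) & 3) << 4)) - 32;
        const int8_t a3 = (int8_t)((ql[l + 32] >>   4) | (((qh[l] >> 6) & 3) << 4)) - 32;

        /* one 8-bit scale per 16 output values, eight scales per 128-value chunk */
        sum += scale[0 + (l >> 4)] * a0 * q8[l +  0]
             + scale[2 + (l >> 4)] * a1 * q8[l + 32]
             + scale[4 + (l >> 4)] * a2 * q8[l + 64]
             + scale[6 + (l >> 4)] * a3 * q8[l + 96];
    }
    return sum;
}

If I read the VLEN = 128 assembly right, it reaches the same values by keeping shifted copies of qh (vsll.vi / vsrl.vi) masked with 0x30 and folding the -32 bias into vsub.vx with the 32 it already holds for %[vl32], so only the scheduling differs, not the arithmetic.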