Skip to content

Commit d990e3f

Browse files
authored
ggml : speed-up ggml_vec_dot_q4_1() ARM_NEON + 32-bit ARM support (#900)
* ggml : speed-up q4_1 ARM_NEON by ~5%
* ggml : implement vaddvq when missing
* ggml : implement vminvq and vmaxvq when missing
* ggml : implement vzip when missing
* ggml : fix comment
* ggml : try to use correct ifdef
1 parent 9190e8e commit d990e3f

File tree

1 file changed

+123
-43
lines changed

1 file changed

+123
-43
lines changed

ggml.c

Lines changed: 123 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
491491
}
492492
#endif
493493

494+
#if __ARM_NEON
495+
496+
#if !defined(__aarch64__)
497+
498+
inline static uint16_t vaddvq_u8(uint8x16_t v) {
499+
return
500+
(uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
501+
(uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
502+
(uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
503+
(uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
504+
(uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
505+
(uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
506+
(uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
507+
(uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508+
}
509+
510+
inline static int32_t vaddvq_s16(int16x8_t v) {
511+
return
512+
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
513+
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
514+
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
515+
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
516+
}
517+
518+
inline static uint32_t vaddvq_u16(uint16x8_t v) {
519+
return
520+
(uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
521+
(uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
522+
(uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
523+
(uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
524+
}
525+
526+
inline static int32_t vaddvq_s32(int32x4_t v) {
527+
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
528+
}
529+
530+
inline static float vaddvq_f32(float32x4_t v) {
531+
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532+
}
533+
534+
inline float vminvq_f32(float32x4_t v) {
535+
return
536+
MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537+
MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538+
}
539+
540+
inline float vmaxvq_f32(float32x4_t v) {
541+
return
542+
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543+
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544+
}
545+
546+
inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547+
return vget_low_s8(vcombine_s8(a, b));
548+
}
549+
550+
inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551+
return vget_high_s8(vcombine_s8(a, b));
552+
}
553+
554+
inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555+
return vget_low_u8(vcombine_u8(a, b));
556+
}
557+
558+
inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559+
return vget_high_u8(vcombine_u8(a, b));
560+
}
561+
562+
#endif
563+
#endif
564+
494565
// method 5
495566
// blocks of QK elements
496567
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
12181289
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
12191290
#define GGML_F32x4_ADD vaddq_f32
12201291
#define GGML_F32x4_MUL vmulq_f32
1221-
#if defined(__ARM_FEATURE_QRDMX)
1222-
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1223-
#else
1224-
#define GGML_F32x4_REDUCE_ONE(x) \
1225-
(vgetq_lane_f32(x, 0) + \
1226-
vgetq_lane_f32(x, 1) + \
1227-
vgetq_lane_f32(x, 2) + \
1228-
vgetq_lane_f32(x, 3))
1229-
#endif
1292+
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
12301293
#define GGML_F32x4_REDUCE(res, x) \
12311294
{ \
12321295
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
18491912
// 4-bit -> 8-bit
18501913
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
18511914
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1852-
18531915
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
18541916
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
18551917

18561918
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
18571919
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1858-
18591920
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
18601921
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
18611922

18621923
// sub 8
18631924
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
18641925
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1865-
18661926
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
18671927
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
18681928

18691929
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
18701930
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1871-
18721931
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
18731932
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
18741933

18751934
#if defined(__ARM_FEATURE_DOTPROD)
1876-
// dot product into int16x8_t
1935+
// dot product into int32x4_t
18771936
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
18781937
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
18791938

18801939
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
18811940
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
18821941

1883-
// scalar
1884-
#if defined(__ARM_FEATURE_QRDMX)
1885-
sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1886-
sum1 += x1->d * y1->d * vaddvq_s32(p_1);
1887-
#else
1888-
sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1889-
sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1890-
#endif
1942+
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943+
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
18911944
#else
18921945
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
18931946
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1894-
18951947
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
18961948
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
18971949

18981950
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
18991951
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1900-
19011952
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
19021953
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
19031954

@@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19101961
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
19111962
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
19121963

1913-
// scalar
1914-
#if defined(__ARM_FEATURE_QRDMX)
1915-
sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1916-
sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1917-
#else
1918-
sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1919-
sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1920-
#endif
1964+
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965+
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
19211966
#endif
19221967
}
19231968

@@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
22652310
float sum10 = 0.0f;
22662311
float sum11 = 0.0f;
22672312

2268-
for (int i = 0; i < nb; ++i) {
2313+
for (int i = 0; i < nb; i += 2) {
22692314
const block_q4_1 * restrict x0 = &x[i + 0];
22702315
const block_q4_1 * restrict y0 = &y[i + 0];
2316+
const block_q4_1 * restrict x1 = &x[i + 1];
2317+
const block_q4_1 * restrict y1 = &y[i + 1];
22712318

22722319
const uint8x16_t m4b = vdupq_n_u8(0xf);
22732320

22742321
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
22752322
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324+
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
22762325

2277-
// and with 0xf
2326+
// 4-bit -> 8-bit
22782327
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
22792328
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2280-
22812329
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
22822330
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
22832331

2284-
// dot product into uint16x8_t
2332+
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333+
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334+
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335+
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336+
2337+
sum00 += x0->m*y0->m;
2338+
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339+
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2340+
2341+
sum00 += x1->m*y1->m;
2342+
sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343+
sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2344+
2345+
#if defined(__ARM_FEATURE_DOTPROD)
2346+
// dot product into int32x4_t
2347+
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
2348+
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
2349+
2350+
p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
2351+
p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
2352+
2353+
sum11 += x0->d*y0->d*vaddvq_s32(p_0);
2354+
sum11 += x1->d*y1->d*vaddvq_s32(p_1);
2355+
#else
22852356
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
22862357
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2287-
22882358
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
22892359
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
22902360

2291-
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2292-
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2361+
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362+
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363+
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364+
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
22932365

2294-
sum00 += x0->m*y0->m;
2295-
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2296-
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2297-
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2366+
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367+
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2368+
2369+
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370+
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2371+
2372+
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373+
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2374+
2375+
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376+
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377+
#endif
22982378
}
22992379

23002380
sumf = QK*sum00 + sum01 + sum10 + sum11;

0 commit comments

Comments
 (0)