@@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
}
#endif

+ #if __ARM_NEON
+
+ #if !defined(__aarch64__)
+
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
+     return
+         (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  +
+         (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  +
+         (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  +
+         (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  +
+         (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  +
+         (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+         (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+ }
+
+ inline static int32_t vaddvq_s16(int16x8_t v) {
+     return
+         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+         (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+         (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+         (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+ }
+
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
+     return
+         (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+         (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+         (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+         (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+ }
+
+ inline static int32_t vaddvq_s32(int32x4_t v) {
+     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+ }
+
+ inline static float vaddvq_f32(float32x4_t v) {
+     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+ }
+
+ inline float vminvq_f32(float32x4_t v) {
+     return
+         MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+             MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ inline float vmaxvq_f32(float32x4_t v) {
+     return
+         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+             MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ // vzip1/vzip2 interleave the elements of two vectors; on ARMv7 they are
+ // emulated with the vzip_* intrinsics, which return both interleaved halves.
+ inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+     return vzip_s8(a, b).val[0];
+ }
+
+ inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+     return vzip_s8(a, b).val[1];
+ }
+
+ inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+     return vzip_u8(a, b).val[0];
+ }
+
+ inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+     return vzip_u8(a, b).val[1];
+ }
+
+ #endif
+ #endif
+
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
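The block above exists because vaddvq_*, vminvq_f32/vmaxvq_f32 and vzip1/vzip2 are AArch64-only intrinsics: on 32-bit ARM (ARMv7 NEON) the horizontal reductions have to be spelled out lane by lane, and the shims let the vectorized kernels below use the same intrinsic names on both targets. A minimal sketch of the pattern they enable (the helper name dot4 is illustrative, not part of the patch):

    #include <arm_neon.h>

    // On AArch64, vaddvq_f32() is a single horizontal-add instruction;
    // on ARMv7 it expands to the four vgetq_lane_f32() additions shimmed above.
    static float dot4(const float * a, const float * b) {
        const float32x4_t va   = vld1q_f32(a);
        const float32x4_t vb   = vld1q_f32(b);
        const float32x4_t prod = vmulq_f32(va, vb);
        return vaddvq_f32(prod); // sum of the four lane products
    }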
@@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32x4_ADD          vaddq_f32
#define GGML_F32x4_MUL          vmulq_f32
- #if defined(__ARM_FEATURE_QRDMX)
-     #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
- #else
-     #define GGML_F32x4_REDUCE_ONE(x) \
-     (vgetq_lane_f32(x, 0) +          \
-      vgetq_lane_f32(x, 1) +          \
-      vgetq_lane_f32(x, 2) +          \
-      vgetq_lane_f32(x, 3))
- #endif
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
#define GGML_F32x4_REDUCE(res, x)              \
{                                              \
    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
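With the fallbacks in place, the reduction macro no longer needs a feature test: __ARM_FEATURE_QRDMX advertises the ARMv8.1 rounding-doubling-multiply instructions and was only serving as a rough stand-in for "running on AArch64", whereas vaddvq_f32 is a base AArch64 intrinsic and is now shimmed on 32-bit ARM as well. What the unconditional macro computes, on illustrative values:

    const float vals[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float32x4_t x = vld1q_f32(vals);
    const float r = vaddvq_f32(x); // GGML_F32x4_REDUCE_ONE(x) == 10.0f on either path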
@@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
        // 4-bit -> 8-bit
        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));

        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));

        // sub 8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);

        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);

#if defined(__ARM_FEATURE_DOTPROD)
-         // dot product into int16x8_t
+         // dot product into int32x4_t
        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);

        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);

-         // scalar
-         #if defined(__ARM_FEATURE_QRDMX)
-         sum0 += x0->d * y0->d * vaddvq_s32(p_0);
-         sum1 += x1->d * y1->d * vaddvq_s32(p_1);
-         #else
-         sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
-         sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-         #endif
+         sum0 += x0->d * y0->d * vaddvq_s32(p_0);
+         sum1 += x1->d * y1->d * vaddvq_s32(p_1);
#else
        const int16x8_t pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0ls));
        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
        const int16x8_t ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0hs));
        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));

        const int16x8_t pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1ls));
        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
        const int16x8_t ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1hs));
        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
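In the dot-product path above, each vdotq_s32 call accumulates signed 8-bit products into 32-bit lanes: output lane i receives the sum of four adjacent byte products. A scalar model of the intrinsic's semantics (the helper name vdotq_s32_model is illustrative, not code from the patch):

    #include <arm_neon.h>

    // out[i] = acc[i] + sum over j = 0..3 of a[4*i+j] * b[4*i+j], for i = 0..3
    static int32x4_t vdotq_s32_model(int32x4_t acc, int8x16_t a, int8x16_t b) {
        int8_t  av[16], bv[16];
        int32_t r[4];
        vst1q_s8(av, a);
        vst1q_s8(bv, b);
        vst1q_s32(r, acc);
        for (int i = 0; i < 4; ++i) {
            for (int j = 0; j < 4; ++j) {
                r[i] += (int32_t)av[4*i + j] * (int32_t)bv[4*i + j];
            }
        }
        return vld1q_s32(r);
    }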
@@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);

-         // scalar
-         #if defined(__ARM_FEATURE_QRDMX)
-         sum0 += x0->d * y0->d * vaddvq_s16(p_0);
-         sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-         #else
-         sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-         sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-         #endif
+         sum0 += x0->d * y0->d * vaddvq_s16(p_0);
+         sum1 += x1->d * y1->d * vaddvq_s16(p_1);
#endif
    }
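Both branches of the q4_0 kernel compute the same per-block contribution: sum0 += x0->d * y0->d * (sum over k of qx0[k]*qy0[k]) across the QK nibbles, with qx/qy the 4-bit values after the sub-8 bias (and likewise sum1 for the second block pair). The DOTPROD path folds the multiplies and 4-wide partial sums into vdotq_s32, while the fallback widens to int16 via vmull_s8 and reduces with vaddvq_s16; dropping the __ARM_FEATURE_QRDMX branches changes nothing numerically, since vaddvq_s16/vaddvq_s32 now exist on every supported target.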
@@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
    float sum10 = 0.0f;
    float sum11 = 0.0f;

-     for (int i = 0; i < nb; ++i) {
+     for (int i = 0; i < nb; i += 2) {
        const block_q4_1 * restrict x0 = &x[i + 0];
        const block_q4_1 * restrict y0 = &y[i + 0];
+         const block_q4_1 * restrict x1 = &x[i + 1];
+         const block_q4_1 * restrict y1 = &y[i + 1];

        const uint8x16_t m4b = vdupq_n_u8(0xf);

        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+         const uint8x16_t v1_1 = vld1q_u8(y1->qs);

-         // and with 0xf
+         // 4-bit -> 8-bit
        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);

-         // dot product into uint16x8_t
+         const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+         const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+         const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+         const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+         sum00 += x0->m * y0->m;
+         sum01 += y0->m * x0->d * (vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+         sum10 += x0->m * y0->d * (vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+         sum00 += x1->m * y1->m;
+         sum01 += y1->m * x1->d * (vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+         sum10 += x1->m * y1->d * (vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+ #if defined(__ARM_FEATURE_DOTPROD)
+         // dot product into uint32x4_t (the quants are unsigned nibbles here,
+         // so the unsigned dot-product intrinsics are used)
+         uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
+         uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+
+         p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
+         p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
+
+         sum11 += x0->d * y0->d * vaddvq_u32(p_0);
+         sum11 += x1->d * y1->d * vaddvq_u32(p_1);
+ #else
        const uint16x8_t pl0l = vmull_u8(vget_low_u8(v0_0l), vget_low_u8(v1_0l));
        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
        const uint16x8_t ph0l = vmull_u8(vget_low_u8(v0_0h), vget_low_u8(v1_0h));
        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));

-         const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
-         const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+         const uint16x8_t pl1l = vmull_u8(vget_low_u8(v0_1l), vget_low_u8(v1_1l));
+         const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+         const uint16x8_t ph1l = vmull_u8(vget_low_u8(v0_1h), vget_low_u8(v1_1h));
+         const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));

-         sum00 += x0->m * y0->m;
-         sum01 += y0->m * x0->d * (vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-         sum10 += x0->m * y0->d * (vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
-         sum11 += x0->d * y0->d * vaddvq_u16(vaddq_u16(pl0, ph0));
+         const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+         const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+         const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+         const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+         const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+         const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+         sum11 += x0->d * y0->d * vaddvq_u16(p_0);
+         sum11 += x1->d * y1->d * vaddvq_u16(p_1);
+ #endif
    }

    sumf = QK*sum00 + sum01 + sum10 + sum11;
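Why the four accumulators: a q4_1 block stores x[k] ≈ x->d * qx[k] + x->m with unsigned nibbles qx[k] in [0, 15], so per block the dot product expands to sum_k x[k]*y[k] = QK*(x->m * y->m) + y->m * x->d * sum_k qx[k] + x->m * y->d * sum_k qy[k] + x->d * y->d * sum_k qx[k]*qy[k]. The loop accumulates those four terms into sum00, sum01, sum10 and sum11 respectively (the nibble sums via vaddvq_u8, the cross term via the dot product or the vmull_u8 fallback), which is exactly what the final line sumf = QK*sum00 + sum01 + sum10 + sum11 recombines.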