Skip to content

Commit 465263d

Browse files
authored
sgemm : AVX Q4_0 and Q8_0 (#6891)
* basic avx implementation
* style
* combine denibble with load
* reduce 256 to 128 (and back!) conversions
* sse load
* Update sgemm.cpp
* oops oops
1 parent 911b390 commit 465263d

File tree

1 file changed

+56
-21
lines changed

1 file changed

+56
-21
lines changed

sgemm.cpp

Lines changed: 56 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
3-
//
41
// Copyright 2024 Mozilla Foundation
52
//
63
// Permission is hereby granted, free of charge, to any person obtaining
@@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
585582
};
586583
#endif // __ARM_FEATURE_DOTPROD
587584

588-
#if defined(__AVX2__) || defined(__AVX512F__)
585+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
589586
template <typename TA, typename TB, typename TC>
590-
class tinyBLAS_Q0_AVX2 {
587+
class tinyBLAS_Q0_AVX {
591588
public:
592-
tinyBLAS_Q0_AVX2(int64_t k,
593-
const TA *A, int64_t lda,
594-
const TB *B, int64_t ldb,
595-
TC *C, int64_t ldc,
596-
int ith, int nth)
589+
tinyBLAS_Q0_AVX(int64_t k,
590+
const TA *A, int64_t lda,
591+
const TB *B, int64_t ldb,
592+
TC *C, int64_t ldc,
593+
int ith, int nth)
597594
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
598595
}
599596

@@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
728725
__m256 Cv[RN][RM] = {};
729726
for (int64_t l = 0; l < k; ++l)
730727
for (int64_t j = 0; j < RN; ++j)
731-
for (int64_t i = 0; i < RM; ++i)
728+
for (int64_t i = 0; i < RM; ++i) {
729+
#if defined(__AVX2__)
730+
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
731+
load(A + lda * (ii + i) + l)),
732+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
733+
load(A + lda * (ii + i) + l)));
734+
#else
735+
__m128i ali0 = load0(A + lda * (ii + i) + l);
736+
__m128i ali1 = load1(A + lda * (ii + i) + l);
737+
__m128i blj0 = load0(B + ldb * (jj + j) + l);
738+
__m128i blj1 = load1(B + ldb * (jj + j) + l);
739+
740+
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
741+
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
742+
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
743+
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
744+
745+
// updot
746+
const __m128i oneFill = _mm_set1_epi16(1);
747+
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
748+
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
749+
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
750+
#endif
732751
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
733752
unhalf(B[ldb * (jj + j) + l].d)),
734-
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
735-
load(A + lda * (ii + i) + l)),
736-
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
737-
load(A + lda * (ii + i) + l))),
738-
Cv[j][i]);
753+
udTmp,
754+
Cv[j][i]);
755+
}
739756
for (int64_t j = 0; j < RN; ++j)
740757
for (int64_t i = 0; i < RM; ++i)
741758
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
746763
return _mm256_loadu_si256((const __m256i *)b->qs);
747764
}
748765

766+
// Returns the first 16 bytes of b->qs as a 128-bit integer vector, using an
// unaligned load. Used by the non-AVX2 code path together with load1(), which
// supplies the following 16 bytes.
inline __m128i load0(const block_q8_0 *b) {
767+
return _mm_loadu_si128((const __m128i *)b->qs);
768+
}
769+
770+
// Returns bytes 16..31 of b->qs as a 128-bit integer vector, using an
// unaligned load (the `+ 1` advances the pointer by one __m128i, i.e. 16 bytes).
inline __m128i load1(const block_q8_0 *b) {
771+
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
772+
}
773+
749774
// Returns denibble(b->qs) with 8 subtracted from every byte.
// NOTE(review): denibble() is defined outside this hunk; presumably it unpacks
// the packed 4-bit values into one byte each, so the subtraction re-centers
// them from [0, 15] to [-8, 7] -- confirm against its definition.
inline __m256i load(const block_q4_0 *b) {
750775
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
751776
}
752777

778+
// 128-bit counterpart of load() for the non-AVX2 path: reads 16 bytes of
// b->qs, keeps the low 4 bits of each byte, and subtracts 8 so every byte
// ends up in [-8, 7].
// NOTE(review): presumably the low nibbles are the first half of the block's
// quantized values -- confirm against the block_q4_0 layout.
inline __m128i load0(const block_q4_0 *b) {
779+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
780+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
781+
}
782+
783+
// 128-bit counterpart of load() for the non-AVX2 path: reads the same 16 bytes
// of b->qs as load0(), but extracts the high 4 bits of each byte and subtracts
// 8 so every byte ends up in [-8, 7].
// The shift works on 16-bit lanes, so bits from the neighbouring byte leak
// into the upper nibble; the mask with 15 discards them.
// NOTE(review): presumably the high nibbles are the second half of the block's
// quantized values -- confirm against the block_q4_0 layout.
inline __m128i load1(const block_q4_0 *b) {
784+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
785+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
786+
}
787+
753788
inline __m256 updot(__m256i u, __m256i s) {
754789
__m256i res;
755790
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
777812
const int ith;
778813
const int nth;
779814
};
780-
#endif // __AVX2__
815+
#endif // __AVX__
781816

782817
} // namespace
783818

@@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
928963
case GGML_TYPE_Q8_0: {
929964
if (Btype != GGML_TYPE_Q8_0)
930965
return false;
931-
#if defined(__AVX2__) || defined(__AVX512F__)
932-
tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
966+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
967+
tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
933968
k, (const block_q8_0 *)A, lda,
934969
(const block_q8_0 *)B, ldb,
935970
(float *)C, ldc,
@@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
952987
case GGML_TYPE_Q4_0: {
953988
if (Btype != GGML_TYPE_Q8_0)
954989
return false;
955-
#if defined(__AVX2__) || defined(__AVX512F__)
956-
tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
990+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
991+
tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
957992
k, (const block_q4_0 *)A, lda,
958993
(const block_q8_0 *)B, ldb,
959994
(float *)C, ldc,

0 commit comments

Comments
 (0)