Skip to content

Commit 1b6dce3

Browse files
committed
iq3_s_mult: borrow tricks from Peter Reid for the AVX2 implementation
1 parent b48bf8b commit 1b6dce3

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

ggml-quants.c

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10074,11 +10074,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
1007410074
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
1007510075
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
1007610076

10077-
const __m256i idx_mask = _mm256_set1_epi32(256);
10078-
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
1007910077
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
1008010078
const __m256i m1 = _mm256_set1_epi8(1);
1008110079
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
10080+
const __m256i m100 = _mm256_set1_epi32(0x0100);
1008210081
#ifdef IQ3S_SLOW_MULT
1008310082
const __m256i m7 = _mm256_set1_epi32(0x07070707);
1008410083
const __m256i m0 = _mm256_setzero_si256();
@@ -10096,23 +10095,26 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
1009610095
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1009710096
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
1009810097
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
10099-
const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
10100-
const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8);
10101-
const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask);
10102-
const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask);
10103-
const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16)));
10104-
const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
10098+
10099+
const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
10100+
const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
10101+
uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
10102+
uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
10103+
const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32(
10104+
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)),
10105+
_mm256_setzero_si256());
10106+
const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32(
10107+
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)),
10108+
_mm256_setzero_si256());
10109+
const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1);
10110+
const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2);
1010510111

1010610112
#ifdef IQ3S_SLOW_MULT
1010710113
const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
1010810114
const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
1010910115
const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
1011010116
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
1011110117
#else
10112-
//const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
10113-
//const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
10114-
//const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
10115-
//const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
1011610118
const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
1011710119
const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
1011810120
#endif

0 commit comments

Comments
 (0)