@@ -10074,11 +10074,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
10074
10074
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
10075
10075
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
10076
10076
10077
- const __m256i idx_mask = _mm256_set1_epi32(256);
10078
- const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
10079
10077
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
10080
10078
const __m256i m1 = _mm256_set1_epi8(1);
10081
10079
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
10080
+ const __m256i m100 = _mm256_set1_epi32(0x0100);
10082
10081
#ifdef IQ3S_SLOW_MULT
10083
10082
const __m256i m7 = _mm256_set1_epi32(0x07070707);
10084
10083
const __m256i m0 = _mm256_setzero_si256();
@@ -10096,23 +10095,26 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
10096
10095
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
10097
10096
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
10098
10097
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
10099
- const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
10100
- const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8);
10101
- const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask);
10102
- const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask);
10103
- const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16)));
10104
- const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
10098
+
10099
+ const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
10100
+ const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
10101
+ uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
10102
+ uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
10103
+ const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32(
10104
+ _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)),
10105
+ _mm256_setzero_si256());
10106
+ const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32(
10107
+ _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)),
10108
+ _mm256_setzero_si256());
10109
+ const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1);
10110
+ const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2);
10105
10111
10106
10112
#ifdef IQ3S_SLOW_MULT
10107
10113
const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
10108
10114
const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
10109
10115
const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
10110
10116
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
10111
10117
#else
10112
- //const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
10113
- //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
10114
- //const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
10115
- //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
10116
10118
const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
10117
10119
const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
10118
10120
#endif
0 commit comments