@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"

 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();

     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
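The zero fill above matters because a later hunk replaces the conditional masked add with an unconditional `sum += maskh_shared[j*D + i_KQ]`: when no mask tensor is passed, the shared buffer simply stays zero and the add is a no-op. Below is a minimal standalone sketch of the same cooperative staging pattern; the kernel and all names in it are illustrative assumptions, not part of this patch or of ggml.

#include <cuda_fp16.h>

// Illustrative sketch only: a block-wide shared buffer sized by the template parameters,
// zero-filled cooperatively (one element per thread and column) so that later reads are
// well defined even when no real data is loaded into it.
template <int D, int ncols>
__global__ void staged_add(const __half * __restrict__ bias, float * __restrict__ out) {
    const int tid = threadIdx.x;              // assumes blockDim.x == D, cf. __launch_bounds__(D, 1)

    __shared__ __half bias_shared[ncols*D];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        bias_shared[j*D + tid] = __float2half(0.0f);
    }
    __syncthreads();                          // buffer is valid for the whole block from here on

    if (bias != nullptr) {                    // optional input: the zero default stands in for it
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            bias_shared[j*D + tid] = bias[j*D + tid];
        }
        __syncthreads();
    }

#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        out[j*D + tid] += __half2float(bias_shared[j*D + tid]);  // unconditional add, as in the kernel above
    }
}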
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
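The block above first stages the slope-scaled mask for the current KV slice into shared memory, then lets the warp vote on whether every staged value is -inf; if so, the whole slice contributes nothing and is skipped, as the in-code comment explains. The full lane mask 0xFFFFFFFF presumably assumes 32-wide warps, hence the `#ifndef GGML_USE_HIP` guard. Here is a standalone sketch of the same vote-and-skip pattern; the kernel and its names are assumptions for illustration, not the ggml API.

#include <cuda_fp16.h>
#include <math.h>

// Illustrative sketch only: one warp walks over tiles of a mask and skips any tile in which
// every lane only ever sees infinities. D is assumed to be a multiple of the 32-lane warp size.
template <int D>
__global__ void count_unmasked_tiles(const __half * __restrict__ mask, int n_tiles, int * __restrict__ n_unmasked) {
    const int lane = threadIdx.x;             // assumes blockDim.x == 32 (one warp per block)

    for (int tile = blockIdx.x; tile < n_tiles; tile += gridDim.x) {
        bool skip = true;
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += 32) {
            const float m = __half2float(mask[tile*D + i0 + lane]);
            skip = skip && isinf(m);          // -inf mask entries mark fully suppressed positions
        }
        if (__all_sync(0xFFFFFFFF, skip)) {   // true only if every lane saw only infinities
            continue;                         // fully masked tile: skip it, as in the kernel above
        }
        if (lane == 0) {
            atomicAdd(n_unmasked, 1);
        }
    }
}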
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
                    sum = logit_softcap*tanhf(sum);
                }

-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += maskh_shared[j*D + i_KQ];

                if (ncols == 1) {
                    kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
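This host-side change reads the device's compute capability and routes every NVIDIA device (and any call with a single query column) to the cols_per_block == 1 instantiation, matching the device-side guard added earlier that turns the ncols > 1 variants into NO_DEVICE_CODE outside HIP/MUSA builds. A minimal sketch of that runtime-to-template dispatch shape follows, with made-up names rather than the real ggml helpers.

// Illustrative sketch only, not the ggml implementation: a runtime property of the device
// and of the batch picks a constexpr value, which is then baked into the kernel
// instantiation as a template argument.
template <int cols_per_block>
static void launch_vec_kernel_case(/* tensors, stream, ... */) {
    // instantiate and launch the kernel templated on cols_per_block here
}

static void dispatch_vec_kernel(int n_query_cols, bool device_is_nvidia) {
    if (n_query_cols == 1 || device_is_nvidia) {
        constexpr int cols_per_block = 1;     // forced on NVIDIA by the change above
        launch_vec_kernel_case<cols_per_block>();
    } else {
        constexpr int cols_per_block = 2;     // illustrative value only
        launch_vec_kernel_case<cols_per_block>();
    }
}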