Commit 2d38b6e

CANN: Add the basic supports of Flash Attention kernel (#13627)
* cann: add the basic FA support
* cann: update the readme
* cann: update the FlashAttention with PSEShift
* cann: update the input parameters in FA
* cann: update the alibi with max_bias
* cann: add the constraints of softcap
* cann: update the docs CANN.md
* cann: update the docs CANN.md
* cann: fix typo of CANN.md
* cann: add some comments and update the CANN.md
* cann: update the CANN.md
* cann: update the inner precise for fusedInferAttention
* cann: update the constraints of flash_attn_ext on ggml-cann.cpp
* cann: clean the whitespace
* cann: clean the whitespace
* cann: add a new endline
1 parent e121edc commit 2d38b6e

File tree: 9 files changed, +392 −0 lines changed

docs/backend/CANN.md (file mode changed: 100644 → 100755)

Lines changed: 9 additions & 0 deletions
```diff
@@ -280,6 +280,15 @@ cmake --build build --config release
 ### **GitHub contribution**:
 Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
 
+## Updates
+### Basic Flash Attention Support
+The basic FA kernel with aclnnops has been added in aclnn_ops.cpp.
+Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap.
+Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future.
+
+Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn).
+
+We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request.
 
 ## TODO
 - Support more models and data types.
```
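For context on the softcap constraint above: in ggml, the logit softcap is a tanh-based bound applied to attention scores before softmax (used by models such as Gemma-2). Below is a minimal sketch of that transform, assuming ggml's scalar `logit_softcap` convention; it is not code from this commit. The fused aclnn attention interface exposes no hook for this transform, which is why the CANN FA path only handles the softcap == 0 case.

```cpp
#include <cmath>

// Sketch only (not from this commit): the tanh-based logit softcap that
// ggml applies to attention scores before softmax when the parameter is
// non-zero. The fused aclnn attention call cannot inject this transform,
// so the CANN FA path requires softcap == 0.
static float apply_logit_softcap(float score, float softcap) {
    if (softcap == 0.0f) {
        return score;                                 // capping disabled
    }
    return softcap * std::tanh(score / softcap);      // bounds score to (-softcap, softcap)
}
```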

ggml/src/ggml-cann/CMakeLists.txt (file mode changed: 100644 → 100755)

File mode changed.

ggml/src/ggml-cann/Doxyfile (file mode changed: 100644 → 100755)

File mode changed.

ggml/src/ggml-cann/acl_tensor.cpp (file mode changed: 100644 → 100755)

Lines changed: 2 additions & 0 deletions

```diff
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
```
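This mapping is what lets the backend describe BF16 ggml tensors to ACL directly. A hedged caller-side sketch (the variable names are illustrative, not from the commit):

```cpp
// Sketch: mapping a ggml tensor's element type to the matching ACL dtype
// when building an ACL tensor descriptor. With the new case above, BF16
// tensors now map to ACL_BF16 rather than having no CANN-side representation.
aclDataType acl_dtype = ggml_cann_type_mapping(src->type); // src is a ggml_tensor*
```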

ggml/src/ggml-cann/acl_tensor.h (file mode changed: 100644 → 100755)

File mode changed.

ggml/src/ggml-cann/aclnn_ops.cpp (file mode changed: 100644 → 100755)

Lines changed: 330 additions & 0 deletions (large diff not rendered by default)
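Since the 330-line kernel body is not rendered on this page, here is a high-level outline of what the commit message suggests the kernel does (PSEShift for ALiBi, the fused inference-attention operator, an inner-precise mode). Everything below is inferred; the names and steps are descriptive placeholders, not the actual code.

```cpp
// Hypothetical outline only — the real aclnn_ops.cpp body is not shown here.
// Steps are inferred from the commit message, not copied from the source.
void ggml_cann_flash_attn_ext_outline(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * q    = dst->src[0]; // queries
    ggml_tensor * k    = dst->src[1]; // keys   (F16 on this path)
    ggml_tensor * v    = dst->src[2]; // values (F16 on this path)
    ggml_tensor * mask = dst->src[3]; // optional attention mask

    // 1. Read scale / max_bias from op_params (logit_softcap is already
    //    guaranteed to be 0 by supports_op).
    // 2. If max_bias > 0, build an ALiBi PSE-shift tensor per head.
    // 3. Invoke the fused aclnn inference-attention operator with the chosen
    //    inner-precise mode, then copy/cast the result into dst.
}
```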

ggml/src/ggml-cann/aclnn_ops.h (file mode changed: 100644 → 100755)

Lines changed: 15 additions & 0 deletions

```diff
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ * for computing scaled dot-product attention with hardware acceleration.
+ * The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
```
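For orientation, a minimal graph-side sketch of how a node reaches this kernel, assuming the standard ggml API: `ggml_flash_attn_ext` emits a `GGML_OP_FLASH_ATTN_EXT` node, which the CANN backend dispatches to `ggml_cann_flash_attn_ext`. Shapes and sizes below are illustrative; note the constraints from this commit (F16 K/V, equal K/V head sizes, zero logit softcap).

```cpp
#include <cmath>
#include "ggml.h"

// Sketch only: building a FLASH_ATTN_EXT node this kernel can serve.
// Dimensions are hypothetical, not taken from the commit.
static ggml_tensor * build_fa_node(ggml_context * ctx) {
    const int64_t head_dim = 128, n_head = 32, n_kv = 512, n_tokens = 16;
    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1); // K: F16 required
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1); // V: F16, same head size as K
    return ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/nullptr,
                               1.0f / std::sqrt((float) head_dim), // scale
                               0.0f,                               // max_bias: ALiBi disabled
                               0.0f);                              // logit_softcap: must stay 0 on CANN
}
```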

ggml/src/ggml-cann/common.h (file mode changed: 100644 → 100755)

File mode changed.

ggml/src/ggml-cann/ggml-cann.cpp (file mode changed: 100644 → 100755)

Lines changed: 36 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
```
```diff
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
```
```diff
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
             return false;
     }
```
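The memcpy at the end of the new case reads the third float in `op->op_params`. A hedged sketch of the layout that read assumes, based on our reading of ggml's `ggml_flash_attn_ext`, which packs three floats (scale, max_bias, logit_softcap) into the op parameters:

```cpp
#include <cstring>
#include "ggml.h"

// Sketch (assumption from ggml's ggml_flash_attn_ext): op_params for
// GGML_OP_FLASH_ATTN_EXT holds three packed floats.
struct fa_params {
    float scale;         // index 0: typically 1/sqrt(head_dim)
    float max_bias;      // index 1: ALiBi slope bias, 0 when disabled
    float logit_softcap; // index 2: the value the check above rejects if non-zero
};

static fa_params read_fa_params(const ggml_tensor * op) {
    fa_params p{};
    std::memcpy(&p, op->op_params, sizeof(p)); // same layout the memcpy above indexes into
    return p;
}
```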
