[CANN]: add the basic supports of Flash Attention kernel #13627


Merged
merged 18 commits on May 26, 2025
Changes from 12 commits
9 changes: 9 additions & 0 deletions docs/backend/CANN.md
100644 → 100755
@@ -258,6 +258,15 @@ cmake --build build --config release
### **GitHub contribution**:
Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.

## Updates
### Basic Flash Attention Support
A basic Flash Attention (FA) kernel implemented with aclnn operators has been added in aclnn_ops.cpp.
Currently, FA only supports FP16 KV tensors and no logit softcap.
Since the aclnn flash-attention interface does not support logit softcap, future updates will only cover quantized KV support. A short usage sketch follows this file's diff.

Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn).

We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd. for their help during code development and the pull request process.

## TODO
- Support more models and data types.
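For illustration, here is a minimal sketch of how the constraints above translate into a ggml graph node and a backend support query. It assumes the upstream `ggml_flash_attn_ext` and `ggml_backend_supports_op` APIs and the `ggml_backend_cann_init` entry point; the tensor shapes are hypothetical and only tensor metadata is built (no data buffers).

```cpp
#include <cmath>
#include <cstdio>

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cann.h"

int main() {
    // Hypothetical shapes: head size 128, 32 heads, 1 query token, 256 cached KV tokens.
    const int64_t head_dim = 128, n_head = 32, n_tokens = 1, n_kv = 256;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // only tensor metadata is needed for the support query
    };
    struct ggml_context * ctx = ggml_init(params);

    // Q stays FP32; K and V must be FP16 for the CANN FA path described above.
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);

    // scale = 1/sqrt(head_dim), max_bias = 0 (no ALiBi), logit_softcap = 0 (required by this kernel).
    struct ggml_tensor * fa = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/NULL,
                                                  1.0f / std::sqrt((float) head_dim), 0.0f, 0.0f);

    // Ask the CANN backend (device 0) whether it can run this node.
    ggml_backend_t backend = ggml_backend_cann_init(0);
    const bool supported = backend != NULL && ggml_backend_supports_op(backend, fa);
    std::printf("CANN FA supported for this node: %s\n", supported ? "yes" : "no");

    if (backend != NULL) {
        ggml_backend_free(backend);
    }
    ggml_free(ctx);
    return 0;
}
```

With quantized or BF16 K/V, or a non-zero logit softcap, the same query would report the node as unsupported and ggml would fall back to another backend.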
Empty file modified ggml/src/ggml-cann/CMakeLists.txt
100644 → 100755
Empty file.
Empty file modified ggml/src/ggml-cann/Doxyfile
100644 → 100755
Empty file.
2 changes: 2 additions & 0 deletions ggml/src/ggml-cann/acl_tensor.cpp
100644 → 100755
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
return ACL_FLOAT;
case GGML_TYPE_F16:
return ACL_FLOAT16;
case GGML_TYPE_BF16:
return ACL_BF16;
case GGML_TYPE_I8:
return ACL_INT8;
case GGML_TYPE_I16:
Empty file modified ggml/src/ggml-cann/acl_tensor.h
100644 → 100755
Empty file.
331 changes: 331 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
100644 → 100755

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
100644 → 100755
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*/
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
* @brief Performs the Flash Attention extended operator using the CANN backend.
*
* @details This function implements the memory-efficient Flash Attention algorithm
* for computing scaled dot-product attention with hardware acceleration.
* The result is stored in the destination tensor `dst`.
*
* This operation is accelerated using the CANN backend to improve runtime performance.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor where the result will be stored.
* dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
*/
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/*
* @brief A generic wrapper for ACL resources with custom deleter support.
*/
Empty file modified ggml/src/ggml-cann/common.h
100644 → 100755
Empty file.
33 changes: 33 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
100644 → 100755
@@ -1747,6 +1747,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_COUNT_EQUAL:
ggml_cann_count_equal(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
default:
return false;
}
@@ -2161,6 +2164,36 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
// adapted from ggml-cuda.cu
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[0]->ne[0] == 128) {
Contributor commented:
I think there's a logical issue here. Currently, it seems that when op->src[0]->ne[0] == 128, the code allows kv to have a data type like q4/q8, implying that this case is supported. However, quantized formats are actually not supported at the moment. I believe the logic should be adjusted accordingly to reflect this:

if (op->src[0]->ne[0] != 128) {
	return false;
}

Could you please help confirm the logic?

Contributor Author replied:
We have updated the if-else logic to pass all of the tests.

return true;
}
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
Contributor commented:
It seems that the current FA doesn't support cases where logitSoftcap is not equal to 0, so we should add a check to ensure logitSoftcap equals 0 here, as shown in the code below.

float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float *) op->op_params + 2, sizeof(float));
if (logitSoftcap != 0.0f) {
	return false;
}

Contributor Author replied:
Thanks for your comment. We have added it.

return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
default:
return false;
}
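Taken together, the two review suggestions above amount to something like the standalone predicate sketched below: only FP16 K/V, a restricted set of head sizes, and a zero logit softcap are accepted. The helper name is hypothetical and the exact if-else structure that was merged may differ.

```cpp
#include <cstring>

#include "ggml.h"

// Hypothetical helper mirroring the GGML_OP_FLASH_ATTN_EXT branch of
// ggml_backend_cann_supports_op with both review suggestions applied.
static bool cann_fa_ext_supported_sketch(const struct ggml_tensor * op) {
    const struct ggml_tensor * q = op->src[0];
    const struct ggml_tensor * k = op->src[1];
    const struct ggml_tensor * v = op->src[2];

    // different head sizes of K and V are not supported yet
    if (k->ne[0] != v->ne[0]) {
        return false;
    }
    // unsupported head sizes (576 corresponds to DeepSeek MLA)
    if (q->ne[0] == 192 || q->ne[0] == 576) {
        return false;
    }
    // the 4th dimension must be 1
    if (q->ne[3] != 1) {
        return false;
    }
    // only FP16 K/V for now: no BF16 and no quantized formats
    if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
        return false;
    }
    // the aclnn FA interface does not support logit softcap
    float logitSoftcap = 0.0f;
    memcpy(&logitSoftcap, (const float *) op->op_params + 2, sizeof(float));
    if (logitSoftcap != 0.0f) {
        return false;
    }
    // head sizes exercised by the existing tests
    return q->ne[0] == 64 || q->ne[0] == 128 || q->ne[0] == 256;
}
```

Compared with the version shown in this diff, a head size of 128 no longer implicitly accepts quantized K/V, and a non-zero logit softcap is rejected up front.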