From 72df31dfc215430a3aca12bc7e13a2b9acd63e8a Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Wed, 14 May 2025 22:40:53 +0800 Subject: [PATCH 01/16] cann: add the basic FA support --- ggml/src/ggml-cann/aclnn_ops.cpp | 160 +++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 17 ++++ ggml/src/ggml-cann/ggml-cann.cpp | 33 +++++++ 3 files changed, 210 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 67c0223c010a1..c1b7fc1409efb 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2587,3 +2587,163 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); } + +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + + ggml_tensor* src0 = dst->src[0]; // q, fp32 + ggml_tensor* src1 = dst->src[1]; // k, fp16 + ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src3 = dst->src[3]; // mask, fp16 + + size_t faElemSize = sizeof(uint16_t); + + // Step 1: cast the src0 (Query) to fp16 + aclTensor* acl_src0_f16_tensor = nullptr; + + ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); + void* src0_f16_buffer = nullptr; + + if(src0->type != GGML_TYPE_F16){ + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + + src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); + src0_f16_buffer = src0_f16_allocator.get(); + + int64_t* src0_f16_ne = src0->ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; + src0_f16_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; + } + + acl_src0_f16_tensor = ggml_cann_create_tensor( + src0_f16_buffer, ACL_FLOAT16, faElemSize, + src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS + ); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, ACL_FLOAT16); + ggml_cann_release_resources(ctx, acl_src0_f32_tensor); + }else{ + acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + } + + // Step 2: genetates mask with ACL_BOOL + size_t maskElemSize = sizeof(char); + ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); + src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); + void* src3_bool_buffer = src3_bool_allocator.get(); + + int64_t* src3_bool_ne = src3->ne; + size_t src3_bool_nb[GGML_MAX_DIMS]; + src3_bool_nb[0] = maskElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; + } + + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + aclTensor* acl_mask_bool_tensor = ggml_cann_create_tensor( + src3_bool_buffer, ACL_BOOL, maskElemSize, + src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + + // Step 3: generates the output tensor directly from FA kernel + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* out_f16_buffer = out_f16_allocator.get(); + + int64_t* out_f16_ne = src0->ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, ACL_FLOAT16, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + + // Step 4: Performs the f16 Flash Attention kernel + + int kvTensorNum = 1; + aclTensor* acl_q_tensor = 
acl_src0_f16_tensor; + aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)}; + aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)}; + auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + aclTensor* acl_out_tensor = acl_out_f16_tensor; + + + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; + double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = {'B', 'N', 'S', 'D', 0}; + int64_t sparseMode = 0; + int64_t innerPrecise = 1; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; + int64_t valueAntiquantMode = 0; + + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, + acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v + nullptr, acl_mask_bool_tensor, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + acl_out_tensor, // attentionOut + nullptr // softmaxLse + ); + + // Step 5: post-processing, permute and cast to f32 + int64_t new_dim[] = {0, 2, 1, 3}; + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + + if(dst->type != GGML_TYPE_F16){ + ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); + perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + + int64_t* perm_out_f16_ne = dst->ne; + size_t perm_out_f16_nb[GGML_MAX_DIMS]; + perm_out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + } + aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( + perm_out_f16_buffer, ACL_FLOAT16, faElemSize, + perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_cast(ctx, + acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); + ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + }else{ + // only need to permute + aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + } + + ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, + acl_mask_bool_tensor, acl_out_f16_tensor, + acl_dst_tensor); + +} \ No newline at end of file diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 462351542e546..a4fedc29cb680 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -45,6 +45,8 @@ #include #include #include +#include 
+#include #include "acl_tensor.h" #include "common.h" @@ -714,6 +716,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Performs the Flash Attention extended operator using the CANN backend. + * + * @details This function implements the memory-efficient Flash Attention algorithm + * for computing scaled dot-product attention with hardware acceleration. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`. + */ +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /* * @brief A generic wrapper for ACL resources with custom deleter support. */ diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index e2617b06e9c39..f4fd563556c9b 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1747,6 +1747,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_COUNT_EQUAL: ggml_cann_count_equal(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cann_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2161,6 +2164,36 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_FLASH_ATTN_EXT:{ + // copy from [ggml-cuda.cu] + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek MLA + return false; + } + if (op->src[0]->ne[3] != 1) { + return false; + } + if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { + return false; + } + if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { + return true; + } + if (op->src[0]->ne[0] == 128) { + return true; + } + if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { + return true; + } + return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + } default: return false; } From 3a731825fc78932d16d005e3fc0e2691d11b5a4d Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Thu, 15 May 2025 14:16:41 +0800 Subject: [PATCH 02/16] cann: update the readme --- docs/backend/CANN.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 23f10175a6b2d..9bd2a9127eee6 100644 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -258,6 +258,11 @@ cmake --build build --config release ### **GitHub contribution**: Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. +### Basic Flash Attention Support +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. + +Authors: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). + ## TODO - Support more models and data types. 
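For reference, the aclnn-based kernel wired up above is expected to reproduce plain scaled dot-product attention with an additive fp16 mask, out = softmax(scale * Q K^T + mask) V with scale = 1/sqrt(D), where the ggml mask is what the kernel call passes as its pse input. Below is a minimal single-head CPU sketch of that reference computation; the helper name naive_attention and its layout assumptions are illustrative only and are not part of this patch series.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// out = softmax(scale * Q K^T + mask) V for one head.
// Q: [S_q, D], K: [S_kv, D], V: [S_kv, D], mask: [S_q, S_kv], row-major fp32.
std::vector<float> naive_attention(const std::vector<float>& Q,
                                   const std::vector<float>& K,
                                   const std::vector<float>& V,
                                   const std::vector<float>& mask,
                                   int S_q, int S_kv, int D) {
    const float scale = 1.0f / std::sqrt((float) D);
    std::vector<float> out((size_t) S_q * D, 0.0f);
    std::vector<float> scores(S_kv);
    for (int i = 0; i < S_q; ++i) {
        // scaled dot products plus the additive mask (the "pse" in the kernel call)
        float score_max = -INFINITY;
        for (int j = 0; j < S_kv; ++j) {
            float s = 0.0f;
            for (int d = 0; d < D; ++d) {
                s += Q[(size_t) i * D + d] * K[(size_t) j * D + d];
            }
            scores[j] = scale * s + mask[(size_t) i * S_kv + j];
            score_max = std::max(score_max, scores[j]);
        }
        // numerically stable softmax over the key dimension
        float sum = 0.0f;
        for (int j = 0; j < S_kv; ++j) {
            scores[j] = std::exp(scores[j] - score_max);
            sum += scores[j];
        }
        // weighted sum of V rows
        for (int j = 0; j < S_kv; ++j) {
            const float w = scores[j] / sum;
            for (int d = 0; d < D; ++d) {
                out[(size_t) i * D + d] += w * V[(size_t) j * D + d];
            }
        }
    }
    return out;
}

The CANN path feeds fp16 Q/K/V into FusedInferAttentionScoreV2, so small numerical differences against an fp32 reference like this are expected.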
From 6a39d6382858b0a26428fc42b3792c5e8363e17e Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Sun, 18 May 2025 22:38:51 +0800 Subject: [PATCH 03/16] cann: update the FlashAttention with PSEShift --- ggml/src/ggml-cann/CMakeLists.txt | 0 ggml/src/ggml-cann/Doxyfile | 0 ggml/src/ggml-cann/acl_tensor.cpp | 2 + ggml/src/ggml-cann/acl_tensor.h | 0 ggml/src/ggml-cann/aclnn_ops.cpp | 614 +++++++++++++++++++++++++++--- ggml/src/ggml-cann/aclnn_ops.h | 0 ggml/src/ggml-cann/common.h | 0 ggml/src/ggml-cann/ggml-cann.cpp | 0 ggml/src/ggml-cann/ifa.py | 43 +++ 9 files changed, 609 insertions(+), 50 deletions(-) mode change 100644 => 100755 ggml/src/ggml-cann/CMakeLists.txt mode change 100644 => 100755 ggml/src/ggml-cann/Doxyfile mode change 100644 => 100755 ggml/src/ggml-cann/acl_tensor.cpp mode change 100644 => 100755 ggml/src/ggml-cann/acl_tensor.h mode change 100644 => 100755 ggml/src/ggml-cann/aclnn_ops.cpp mode change 100644 => 100755 ggml/src/ggml-cann/aclnn_ops.h mode change 100644 => 100755 ggml/src/ggml-cann/common.h mode change 100644 => 100755 ggml/src/ggml-cann/ggml-cann.cpp create mode 100644 ggml/src/ggml-cann/ifa.py diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/Doxyfile b/ggml/src/ggml-cann/Doxyfile old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp old mode 100644 new mode 100755 index f5462c5a18e37..f311864d486f7 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_FLOAT; case GGML_TYPE_F16: return ACL_FLOAT16; + case GGML_TYPE_BF16: + return ACL_BF16; case GGML_TYPE_I8: return ACL_INT8; case GGML_TYPE_I16: diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100644 new mode 100755 index c1b7fc1409efb..cef33ee71e8ba --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -72,6 +72,15 @@ #include #include +#include +#include +#include +#include + +#include "aclnnop/aclnn_flash_attention_score.h" +#include "aclnnop/aclnn_logical_not.h" + +#include "ggml-cann/acl_tensor.h" #include "ggml-impl.h" #define GGML_COMMON_DECL_C @@ -2588,26 +2597,108 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); } +#ifdef DEBUG +static int cnt = 0; + +static void PrintOutResultShort(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { + auto size = ne[0] * ne[1] * ne[2] * ne[3]; + std::vector resultData(size, 0); + auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), + *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. 
ERROR: %d\n", ret); return); + + // 打开文件 + std::string filename = "output_acl_short_" + std::to_string(cnt) + "_" + s + ".txt"; + cnt++; + std::ofstream outFile(filename); + + // 将数据写入文件 + for(size_t i = 0; i < size; ++i){ + outFile << GGML_FP16_TO_FP32(resultData[i]) << " "; + if(i > 0 && i % ne[0] == 0){ + outFile << "\n"; + } + } + outFile << std::endl << std::endl; + // 关闭文件 + outFile.close(); +} + + +static void PrintOutResultChar(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { + auto size = ne[0] * ne[1] * ne[2] * ne[3]; + std::vector resultData(size, 0); + auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), + *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); + + // 打开文件 + std::string filename = "output_acl_char_" + std::to_string(cnt) + "_" + s + ".txt"; + cnt++; + std::ofstream outFile(filename); + + // 将数据写入文件 + for(size_t i = 0; i < size; ++i){ + outFile << int(resultData[i]) << " "; + if(i > 0 && i % ne[0] == 0){ + outFile << "\n"; + } + } + outFile << std::endl << std::endl; + // 关闭文件 + outFile.close(); +} + +static void PrintOutResultFloat(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { + auto size = ne[0] * ne[1] * ne[2] * ne[3]; + std::vector resultData(size, 0); + auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), + *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); + + // 打开文件 + std::string filename = "output_acl_short_" + std::to_string(cnt) + "_" + s + ".txt"; + cnt++; + std::ofstream outFile(filename); + + // 将数据写入文件 + for(size_t i = 0; i < size; ++i){ + outFile << float(resultData[i]) << " "; + if(i > 0 && i % ne[0] == 0){ + outFile << "\n"; + } + } + outFile << std::endl << std::endl; + // 关闭文件 + outFile.close(); +} + + +#endif + void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - + ggml_tensor* src0 = dst->src[0]; // q, fp32 ggml_tensor* src1 = dst->src[1]; // k, fp16 ggml_tensor* src2 = dst->src[2]; // v, fp16 ggml_tensor* src3 = dst->src[3]; // mask, fp16 size_t faElemSize = sizeof(uint16_t); - - // Step 1: cast the src0 (Query) to fp16 + auto faDataType = ACL_FLOAT16; //ACL_BF16; + aclTensor* acl_src0_f16_tensor = nullptr; - + aclTensor* acl_src1_f16_tensor = nullptr; + aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor* acl_src3_f16_tensor = nullptr; + aclTensor* acl_dst_f16_tensor = nullptr; + + // Step 1: cast the src0 (Query) to fp16 ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); void* src0_f16_buffer = nullptr; - if(src0->type != GGML_TYPE_F16){ + if(ggml_cann_type_mapping(src0->type) != faDataType){ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); - - src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); - src0_f16_buffer = src0_f16_allocator.get(); + src0_f16_buffer = src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); int64_t* src0_f16_ne = src0->ne; size_t src0_f16_nb[GGML_MAX_DIMS]; @@ -2617,40 +2708,26 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } acl_src0_f16_tensor = ggml_cann_create_tensor( - src0_f16_buffer, ACL_FLOAT16, faElemSize, + src0_f16_buffer, faDataType, faElemSize, src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS ); - aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, 
ACL_FLOAT16); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); ggml_cann_release_resources(ctx, acl_src0_f32_tensor); }else{ acl_src0_f16_tensor = ggml_cann_create_tensor(src0); } - // Step 2: genetates mask with ACL_BOOL - size_t maskElemSize = sizeof(char); - ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); - src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); - void* src3_bool_buffer = src3_bool_allocator.get(); - - int64_t* src3_bool_ne = src3->ne; - size_t src3_bool_nb[GGML_MAX_DIMS]; - src3_bool_nb[0] = maskElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; - } - - aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); - aclTensor* acl_mask_bool_tensor = ggml_cann_create_tensor( - src3_bool_buffer, ACL_BOOL, maskElemSize, - src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2); - GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); - ggml_cann_release_resources(ctx, acl_mask_f16_tensor); +#ifdef DEBUG + PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-1"); + PrintOutResultShort(src1->ne, &(src1->data), "k-1"); + PrintOutResultShort(src2->ne, &(src2->data), "v-1"); +#endif - // Step 3: generates the output tensor directly from FA kernel ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - void* out_f16_buffer = out_f16_allocator.get(); + void* out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); int64_t* out_f16_ne = src0->ne; size_t out_f16_nb[GGML_MAX_DIMS]; @@ -2659,21 +2736,134 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } - aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor( - out_f16_buffer, ACL_FLOAT16, faElemSize, + acl_dst_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS ); - // Step 4: Performs the f16 Flash Attention kernel +#ifdef DEBUG + PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-2"); + PrintOutResultShort(src1->ne, &(src1->data), "k-2"); + PrintOutResultShort(src2->ne, &(src2->data), "v-2"); + PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-2"); +#endif + + aclTensor* acl_mask_f16_tensor = nullptr; + aclTensor* acl_mask_bool_tensor = nullptr; + aclTensor* bcast_pse_tensor = nullptr; + + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); + void* bcast_pse_buffer = nullptr; + if(src3) + bcast_pse_buffer = + bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + + if(src3 != nullptr){ +#ifdef DEBUG + PrintOutResultShort(src3->ne, &(src3->data), "mask"); +#endif + // size_t maskElemSize = sizeof(char); + + // ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); + // void* src3_bool_buffer = src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); + // int64_t* src3_bool_ne = src3->ne; + // size_t src3_bool_nb[GGML_MAX_DIMS]; + // src3_bool_nb[0] = maskElemSize; + // for(int i = 1; i < GGML_MAX_DIMS; ++i){ + // src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; + // } + + + // acl_mask_bool_tensor = ggml_cann_create_tensor( + // src3_bool_buffer, ACL_BOOL, maskElemSize, + // src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); + + // 
GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); + // GGML_CANN_CALL_ACLNN_OP(ctx, InplaceLogicalNot, acl_mask_bool_tensor); + + + // PrintOutResultChar(src3->ne, &(src3_bool_buffer), "mask"); + + // broadcast pse + if(src0->ne[1] > 1){ + acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src3->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + // int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src3->ne[1], src0->ne[2], src3->ne[3]}; + // size_t bcast_pse_nb[GGML_MAX_DIMS]; + bcast_pse_nb[0] = sizeof(uint16_t); + + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); +#ifdef DEBUG + PrintOutResultShort(bcast_pse_ne, &(src3->data), "repeat-1"); +#endif + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + }else{ + // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS + ); + + // bcast_pse_buffer = + // bcast_pse_allocator.alloc(src3->ne[0] * src0->ne[1] * src0->ne[2] * src3->ne[3] * sizeof(uint16_t)); + + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src0->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + // int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src0->ne[2], src3->ne[3]}; + // size_t bcast_pse_nb[GGML_MAX_DIMS]; + bcast_pse_nb[0] = sizeof(uint16_t); + + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); +#ifdef DEBUG + PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-1"); +#endif + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } + } + +#ifdef DEBUG + PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-3"); + PrintOutResultShort(src1->ne, &(src1->data), "k-3"); + PrintOutResultShort(src2->ne, &(src2->data), "v-3"); + PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-3"); +#endif + + int kvTensorNum = 1; aclTensor* acl_q_tensor = acl_src0_f16_tensor; - aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)}; - aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)}; + aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; + aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); - aclTensor* acl_out_tensor = acl_out_f16_tensor; - int64_t numHeads = src0->ne[2]; // N int64_t numKeyValueHeads = src1->ne[2]; @@ -2682,18 +2872,28 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t nextTokens = 65535; char layout[5] = {'B', 'N', 'S', 'D', 0}; int64_t sparseMode = 0; - int64_t 
innerPrecise = 1; + int64_t innerPrecise = 2; int64_t blockSize = 0; int64_t antiquantMode = 0; bool softmaxLseFlag = false; int64_t keyAntiquantMode = 0; int64_t valueAntiquantMode = 0; + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md +#ifdef DEBUG + PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-4"); + PrintOutResultShort(src1->ne, &(src1->data), "k-4"); + PrintOutResultShort(src2->ne, &(src2->data), "v-4"); + PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-4"); + if(src3) + PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-4"); +#endif + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v - nullptr, acl_mask_bool_tensor, // pse, mask + bcast_pse_tensor, nullptr, // pse, mask nullptr, nullptr, // actSeqLen, actSeqLenkv nullptr, nullptr, // deqScale1, quantScale1 nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 @@ -2711,15 +2911,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ blockSize, antiquantMode, // blockSize, antiquantMode softmaxLseFlag, // softmaxLseFlag keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode - acl_out_tensor, // attentionOut + acl_dst_f16_tensor, // attentionOut nullptr // softmaxLse ); +#ifdef DEBUG + PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-5"); + PrintOutResultShort(src1->ne, &(src1->data), "k-5"); + PrintOutResultShort(src2->ne, &(src2->data), "v-5"); + if(src3) + PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-5"); + PrintOutResultShort(out_f16_ne, &out_f16_buffer, "out-5"); +#endif + // Step 5: post-processing, permute and cast to f32 int64_t new_dim[] = {0, 2, 1, 3}; aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - if(dst->type != GGML_TYPE_F16){ + if(ggml_cann_type_mapping(dst->type) != faDataType){ ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); void* perm_out_f16_buffer = perm_out_f16_allocator.get(); @@ -2731,19 +2940,324 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; } aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( - perm_out_f16_buffer, ACL_FLOAT16, faElemSize, + perm_out_f16_buffer, faDataType, faElemSize, perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); - aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); aclnn_cast(ctx, acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); }else{ // only need to permute - aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); } + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor); + if(src3) + ggml_cann_release_resources(ctx, bcast_pse_tensor); +} + - ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, - acl_mask_bool_tensor, acl_out_f16_tensor, - acl_dst_tensor); -} \ No newline at end of file +// void ggml_cann_flash_attn_ext_archive(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + + +// int id = 1; +// std::cout << std::endl; +// while(1){ +// auto ptr = 
dst->src[id - 1]; +// if(ptr != nullptr){ +// std::cout << "src " << id << ": " << ptr->name << " " << ggml_type_name(ptr->type) +// << " ne: " << ptr->ne[0] << "x" +// << ptr->ne[1] << "x" +// << ptr->ne[2] << "x" +// << ptr->ne[3] << " nb: " +// << ptr->nb[0] << "x" +// << ptr->nb[1] << "x" +// << ptr->nb[2] << "x" +// << ptr->nb[3] << '\n'; +// id++; +// }else{ +// break; +// } +// } + + + +// ggml_tensor* src0 = dst->src[0]; // q, fp32 +// ggml_tensor* src1 = dst->src[1]; // k, fp16 +// ggml_tensor* src2 = dst->src[2]; // v, fp16 +// ggml_tensor* src3 = dst->src[3]; // mask, fp16 + +// size_t faElemSize = sizeof(uint16_t); + +// auto faDataType = ACL_FLOAT16; //ACL_BF16; + +// // Step 1: cast the src0 (Query) to fp16 +// aclTensor* acl_src0_f16_tensor = nullptr; +// ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); +// void* src0_f16_buffer = nullptr; + +// if(ggml_cann_type_mapping(src0->type) != faDataType){ +// aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + +// src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); +// src0_f16_buffer = src0_f16_allocator.get(); + +// int64_t* src0_f16_ne = src0->ne; +// size_t src0_f16_nb[GGML_MAX_DIMS]; +// src0_f16_nb[0] = sizeof(uint16_t); +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; +// } + +// acl_src0_f16_tensor = ggml_cann_create_tensor( +// src0_f16_buffer, faDataType, faElemSize, +// src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS +// ); +// aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); +// ggml_cann_release_resources(ctx, acl_src0_f32_tensor); +// }else{ +// acl_src0_f16_tensor = ggml_cann_create_tensor(src0); +// } +// PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q"); + +// // Step 2: genetates mask with ACL_BOOL +// aclTensor* acl_mask_f16_tensor = nullptr; +// aclTensor* acl_mask_bool_tensor = nullptr; +// aclTensor* bcast_pse_tensor = nullptr; + +// if(src3 != nullptr){ +// size_t maskElemSize = sizeof(char); +// ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); +// src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); +// void* src3_bool_buffer = src3_bool_allocator.get(); + +// int64_t* src3_bool_ne = src3->ne; +// size_t src3_bool_nb[GGML_MAX_DIMS]; +// src3_bool_nb[0] = maskElemSize; +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; +// } + +// acl_mask_f16_tensor = ggml_cann_create_tensor(src3); +// acl_mask_bool_tensor = ggml_cann_create_tensor( +// src3_bool_buffer, ACL_BOOL, maskElemSize, +// src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); + +// // GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); +// // GGML_CANN_CALL_ACLNN_OP(ctx, InplaceLogicalNot, acl_mask_bool_tensor); + +// PrintOutResultShort(src3->ne, &(src3->data), "mask"); +// // PrintOutResultChar(src3->ne, &(src3_bool_buffer), "mask"); + +// // broadcast pse +// if(src0->ne[1] > 1){ +// ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); +// void* bcast_pse_buffer = +// bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + +// int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src3->ne[1], src0->ne[2], src3->ne[3]}; +// size_t bcast_pse_nb[GGML_MAX_DIMS]; +// bcast_pse_nb[0] = sizeof(uint16_t); + +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; +// } + +// bcast_pse_tensor = ggml_cann_create_tensor( +// bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), 
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS +// ); + +// int64_t repeats[] = {1, src0->ne[2], 1, 1}; +// aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); +// PrintOutResultShort(bcast_pse_ne, &(src3->data), "mask"); +// }else{ +// // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); +// int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; +// size_t* trunc_pse_nb = src3->nb; +// aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( +// src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS +// ); +// ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); +// void* bcast_pse_buffer = +// bcast_pse_allocator.alloc(src3->ne[0] * src0->ne[1] * src0->ne[2] * src3->ne[3] * sizeof(uint16_t)); + +// int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src0->ne[2], src3->ne[3]}; +// size_t bcast_pse_nb[GGML_MAX_DIMS]; +// bcast_pse_nb[0] = sizeof(uint16_t); + +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; +// } + +// std::cout << bcast_pse_ne[0] << " " << bcast_pse_ne[1] << " " << bcast_pse_ne[2] << " " << bcast_pse_ne[3] << std::endl; +// std::cout << bcast_pse_nb[0] << " " << bcast_pse_nb[1] << " " << bcast_pse_nb[2] << " " << bcast_pse_nb[3] << std::endl; + + +// bcast_pse_tensor = ggml_cann_create_tensor( +// bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS +// ); +// int64_t repeats[] = {1, src0->ne[2], 1, 1}; +// aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); +// PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat"); + +// // aclnn_muls(ctx, bcast_pse_tensor, sqrt(src0->ne[0]), nullptr, true); +// } +// } + +// // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + +// // Step 3: generates the output tensor directly from FA kernel +// ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); +// out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); +// void* out_f16_buffer = out_f16_allocator.get(); + +// int64_t* out_f16_ne = src0->ne; +// size_t out_f16_nb[GGML_MAX_DIMS]; +// out_f16_nb[0] = faElemSize; +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; +// } + +// aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor( +// out_f16_buffer, faDataType, faElemSize, +// out_f16_ne, out_f16_nb, GGML_MAX_DIMS +// ); + +// // Step 4: Performs the f16 Flash Attention kernel + +// int kvTensorNum = 1; +// aclTensor* acl_q_tensor = acl_src0_f16_tensor; + +// auto tmp_k = ggml_cann_create_tensor(src1); +// ggml_cann_pool_alloc k_f32_allocator(ctx.pool()); +// k_f32_allocator.alloc(ggml_nelements(src1) * sizeof(float)); +// void* k_f32_buffer = k_f32_allocator.get(); +// size_t k_f32_nb[GGML_MAX_DIMS]; +// for(int i = 0; i < GGML_MAX_DIMS; ++i){ +// k_f32_nb[i] = src1->nb[i] * 2; +// } + +// auto tmp_k_f32 = ggml_cann_create_tensor(k_f32_buffer, ACL_FLOAT, sizeof(float), +// src1->ne, k_f32_nb, GGML_MAX_DIMS); +// aclnn_cast(ctx, tmp_k, tmp_k_f32, ACL_FLOAT); + +// auto tmp_v = ggml_cann_create_tensor(src2); +// ggml_cann_pool_alloc v_f32_allocator(ctx.pool()); +// v_f32_allocator.alloc(ggml_nelements(src2) * sizeof(float)); +// void* v_f32_buffer = v_f32_allocator.get(); +// size_t v_f32_nb[GGML_MAX_DIMS]; +// for(int i = 0; i < GGML_MAX_DIMS; ++i){ +// v_f32_nb[i] = src1->nb[i] * 2; +// } +// auto tmp_v_f32 = ggml_cann_create_tensor(v_f32_buffer, ACL_FLOAT, sizeof(float), +// src2->ne, v_f32_nb, 
GGML_MAX_DIMS); +// aclnn_cast(ctx, tmp_v, tmp_v_f32, ACL_FLOAT); + +// PrintOutResultFloat(src1->ne, &k_f32_buffer, "k"); +// PrintOutResultFloat(src2->ne, &v_f32_buffer, "v"); + + +// aclTensor* acl_k_tensors[] = {tmp_k}; +// aclTensor* acl_v_tensors[] = {tmp_v}; + +// // aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)}; +// // aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)}; +// auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); +// auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); +// aclTensor* acl_out_tensor = acl_out_f16_tensor; + + +// int64_t numHeads = src0->ne[2]; // N +// int64_t numKeyValueHeads = src1->ne[2]; +// double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) +// int64_t preTokens = 65535; +// int64_t nextTokens = 65535; +// char layout[5] = {'B', 'N', 'S', 'D', 0}; +// int64_t sparseMode = 0; +// int64_t innerPrecise = 2; +// int64_t blockSize = 0; +// int64_t antiquantMode = 0; +// bool softmaxLseFlag = false; +// int64_t keyAntiquantMode = 0; +// int64_t valueAntiquantMode = 0; + +// // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + +// GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, +// acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v +// bcast_pse_tensor, nullptr, // pse, mask +// nullptr, nullptr, // actSeqLen, actSeqLenkv +// nullptr, nullptr, // deqScale1, quantScale1 +// nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 +// nullptr, nullptr, // antiquantScale, antiquantOffset +// nullptr, // blockTable +// nullptr, nullptr, // qPadSize, kvPadSize +// nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset +// nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset +// nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen +// numHeads, scaleValue, // heads, scaleValue +// preTokens, nextTokens, // preTokens, nextTokens +// layout, // inputLayout +// numKeyValueHeads, // numKVHeads +// sparseMode, innerPrecise, // sparseMode, innerPrecise +// blockSize, antiquantMode, // blockSize, antiquantMode +// softmaxLseFlag, // softmaxLseFlag +// keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode +// acl_out_tensor, // attentionOut +// nullptr // softmaxLse +// ); + +// ggml_cann_pool_alloc aux_out_f32_allocator(ctx.pool()); +// aux_out_f32_allocator.alloc(ggml_nelements(dst) * sizeof(float)); +// void* aux_out_f32_buffer = aux_out_f32_allocator.get(); +// int64_t* aux_out_f32_ne = out_f16_ne; +// size_t aux_out_f32_nb[GGML_MAX_DIMS]; +// for(int i = 0; i < GGML_MAX_DIMS; ++i){ +// aux_out_f32_nb[i] = out_f16_nb[i] * 2; +// } +// aclTensor* aux_out_f32_tensor = ggml_cann_create_tensor( +// aux_out_f32_buffer, ACL_FLOAT, sizeof(float), aux_out_f32_ne, aux_out_f32_nb, GGML_MAX_DIMS +// ); + +// aclnn_cast(ctx, +// acl_out_tensor, aux_out_f32_tensor, ggml_cann_type_mapping(dst->type)); + + + +// PrintOutResultFloat(out_f16_ne, &aux_out_f32_buffer, "fia-res-f32"); + +// // Step 5: post-processing, permute and cast to f32 +// int64_t new_dim[] = {0, 2, 1, 3}; +// aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + +// if(ggml_cann_type_mapping(dst->type) != faDataType){ +// ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); +// perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); +// void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + +// int64_t* perm_out_f16_ne = dst->ne; +// size_t perm_out_f16_nb[GGML_MAX_DIMS]; +// perm_out_f16_nb[0] = 
faElemSize; +// for(int i = 1; i < GGML_MAX_DIMS; ++i){ +// perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; +// } +// aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( +// perm_out_f16_buffer, faDataType, faElemSize, +// perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); +// // aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( +// // perm_out_f16_buffer, ACL_FLOAT16, faElemSize, +// // perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); +// aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); +// aclnn_cast(ctx, +// acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); +// ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); +// }else{ +// // only need to permute +// aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); +// } + +// ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, +// acl_mask_bool_tensor, acl_out_f16_tensor, +// acl_dst_tensor); + +// } \ No newline at end of file diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/ifa.py b/ggml/src/ggml-cann/ifa.py new file mode 100644 index 0000000000000..b01f497ef7994 --- /dev/null +++ b/ggml/src/ggml-cann/ifa.py @@ -0,0 +1,43 @@ +# 单算子调用方式 +import torch +import torch_npu +import math + +def load_float_array_to_tensor(file_path, shape, dtype): + with open(file_path, 'r') as file: + # 读取文件内容并按空格分割 + data = file.read().strip().split() + # 将字符串转换为浮点数 + float_array = [float(num) for num in data] + # 转换为 PyTorch 张量 + tensor = torch.tensor(float_array, dtype=dtype).reshape(shape).npu() + return tensor + +batch = 1 +nhead_q = 4 +nhead_kv = nhead_q +seq_q = 1 +dims = 64 +seq_kv = 512 +layout="BNSD" + +scale_value = 1 / pow(dims, 0.5) + +q_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_0_q.txt", + (batch, nhead_q, seq_q, dims), torch.float16) +k_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_3_k.txt", + (batch, nhead_kv, seq_kv, dims), torch.float16) + +v_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_4_v.txt", + (batch, nhead_kv, seq_kv, dims), torch.float16) + +pse_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_1_mask.txt", + (1, 1, -1, seq_kv), torch.float16) + +print(q_tensor.shape, k_tensor.shape, v_tensor.shape, pse_tensor.shape) + +# 调用IFA算子 +out = torch_npu.npu_incre_flash_attention(q_tensor, k_tensor, v_tensor, pse_shift=pse_tensor, + num_heads=nhead_q, num_key_value_heads=nhead_kv, + input_layout=layout, scale_value=scale_value) + From 8a902b9875e40be7a2e766c2e816a2eff078b747 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Sun, 18 May 2025 23:08:14 +0800 Subject: [PATCH 04/16] cann: update the input parameters in FA --- ggml/src/ggml-cann/aclnn_ops.cpp | 361 ++----------------------------- 1 file changed, 16 insertions(+), 345 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index cef33ee71e8ba..0e55de00d5b05 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2683,6 +2683,17 @@ void 
ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_tensor* src2 = dst->src[2]; // v, fp16 ggml_tensor* src3 = dst->src[3]; // mask, fp16 + float maxBias = 0.0; + float scaleValue = 1.0; + float logitSoftcap = 0.0; + memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); + memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); + + // if(logitSoftcap != 0.0f){ + // // call the non-fa implementation + // }else{ + size_t faElemSize = sizeof(uint16_t); auto faDataType = ACL_FLOAT16; //ACL_BF16; @@ -2748,8 +2759,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-2"); #endif - aclTensor* acl_mask_f16_tensor = nullptr; - aclTensor* acl_mask_bool_tensor = nullptr; aclTensor* bcast_pse_tensor = nullptr; int64_t bcast_pse_ne[GGML_MAX_DIMS]; @@ -2764,39 +2773,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ #ifdef DEBUG PrintOutResultShort(src3->ne, &(src3->data), "mask"); #endif - // size_t maskElemSize = sizeof(char); - - // ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); - // void* src3_bool_buffer = src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); - // int64_t* src3_bool_ne = src3->ne; - // size_t src3_bool_nb[GGML_MAX_DIMS]; - // src3_bool_nb[0] = maskElemSize; - // for(int i = 1; i < GGML_MAX_DIMS; ++i){ - // src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; - // } - - - // acl_mask_bool_tensor = ggml_cann_create_tensor( - // src3_bool_buffer, ACL_BOOL, maskElemSize, - // src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); - - // GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); - // GGML_CANN_CALL_ACLNN_OP(ctx, InplaceLogicalNot, acl_mask_bool_tensor); - - - // PrintOutResultChar(src3->ne, &(src3_bool_buffer), "mask"); - // broadcast pse if(src0->ne[1] > 1){ - acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + bcast_pse_ne[0] = src3->ne[0]; bcast_pse_ne[1] = src3->ne[1]; bcast_pse_ne[2] = src0->ne[2]; bcast_pse_ne[3] = src3->ne[3]; - // int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src3->ne[1], src0->ne[2], src3->ne[3]}; - // size_t bcast_pse_nb[GGML_MAX_DIMS]; - bcast_pse_nb[0] = sizeof(uint16_t); + bcast_pse_nb[0] = sizeof(uint16_t); for(int i = 1; i < GGML_MAX_DIMS; ++i){ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; } @@ -2812,30 +2798,22 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ #endif ggml_cann_release_resources(ctx, acl_mask_f16_tensor); }else{ - // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; size_t* trunc_pse_nb = src3->nb; aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS ); - - // bcast_pse_buffer = - // bcast_pse_allocator.alloc(src3->ne[0] * src0->ne[1] * src0->ne[2] * src3->ne[3] * sizeof(uint16_t)); bcast_pse_ne[0] = src3->ne[0]; bcast_pse_ne[1] = src0->ne[1]; bcast_pse_ne[2] = src0->ne[2]; bcast_pse_ne[3] = src3->ne[3]; - // int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src0->ne[2], src3->ne[3]}; - // size_t bcast_pse_nb[GGML_MAX_DIMS]; bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ bcast_pse_nb[i] = 
bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; } - bcast_pse_tensor = ggml_cann_create_tensor( bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS @@ -2867,7 +2845,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t numHeads = src0->ne[2]; // N int64_t numKeyValueHeads = src1->ne[2]; - double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) int64_t preTokens = 65535; int64_t nextTokens = 65535; char layout[5] = {'B', 'N', 'S', 'D', 0}; @@ -2953,311 +2931,4 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor); if(src3) ggml_cann_release_resources(ctx, bcast_pse_tensor); -} - - - -// void ggml_cann_flash_attn_ext_archive(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - - -// int id = 1; -// std::cout << std::endl; -// while(1){ -// auto ptr = dst->src[id - 1]; -// if(ptr != nullptr){ -// std::cout << "src " << id << ": " << ptr->name << " " << ggml_type_name(ptr->type) -// << " ne: " << ptr->ne[0] << "x" -// << ptr->ne[1] << "x" -// << ptr->ne[2] << "x" -// << ptr->ne[3] << " nb: " -// << ptr->nb[0] << "x" -// << ptr->nb[1] << "x" -// << ptr->nb[2] << "x" -// << ptr->nb[3] << '\n'; -// id++; -// }else{ -// break; -// } -// } - - - -// ggml_tensor* src0 = dst->src[0]; // q, fp32 -// ggml_tensor* src1 = dst->src[1]; // k, fp16 -// ggml_tensor* src2 = dst->src[2]; // v, fp16 -// ggml_tensor* src3 = dst->src[3]; // mask, fp16 - -// size_t faElemSize = sizeof(uint16_t); - -// auto faDataType = ACL_FLOAT16; //ACL_BF16; - -// // Step 1: cast the src0 (Query) to fp16 -// aclTensor* acl_src0_f16_tensor = nullptr; -// ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); -// void* src0_f16_buffer = nullptr; - -// if(ggml_cann_type_mapping(src0->type) != faDataType){ -// aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); - -// src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); -// src0_f16_buffer = src0_f16_allocator.get(); - -// int64_t* src0_f16_ne = src0->ne; -// size_t src0_f16_nb[GGML_MAX_DIMS]; -// src0_f16_nb[0] = sizeof(uint16_t); -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; -// } - -// acl_src0_f16_tensor = ggml_cann_create_tensor( -// src0_f16_buffer, faDataType, faElemSize, -// src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS -// ); -// aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); -// ggml_cann_release_resources(ctx, acl_src0_f32_tensor); -// }else{ -// acl_src0_f16_tensor = ggml_cann_create_tensor(src0); -// } -// PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q"); - -// // Step 2: genetates mask with ACL_BOOL -// aclTensor* acl_mask_f16_tensor = nullptr; -// aclTensor* acl_mask_bool_tensor = nullptr; -// aclTensor* bcast_pse_tensor = nullptr; - -// if(src3 != nullptr){ -// size_t maskElemSize = sizeof(char); -// ggml_cann_pool_alloc src3_bool_allocator(ctx.pool()); -// src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize); -// void* src3_bool_buffer = src3_bool_allocator.get(); - -// int64_t* src3_bool_ne = src3->ne; -// size_t src3_bool_nb[GGML_MAX_DIMS]; -// src3_bool_nb[0] = maskElemSize; -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1]; -// } - -// acl_mask_f16_tensor = ggml_cann_create_tensor(src3); -// 
acl_mask_bool_tensor = ggml_cann_create_tensor( -// src3_bool_buffer, ACL_BOOL, maskElemSize, -// src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS); - -// // GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor); -// // GGML_CANN_CALL_ACLNN_OP(ctx, InplaceLogicalNot, acl_mask_bool_tensor); - -// PrintOutResultShort(src3->ne, &(src3->data), "mask"); -// // PrintOutResultChar(src3->ne, &(src3_bool_buffer), "mask"); - -// // broadcast pse -// if(src0->ne[1] > 1){ -// ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); -// void* bcast_pse_buffer = -// bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - -// int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src3->ne[1], src0->ne[2], src3->ne[3]}; -// size_t bcast_pse_nb[GGML_MAX_DIMS]; -// bcast_pse_nb[0] = sizeof(uint16_t); - -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; -// } - -// bcast_pse_tensor = ggml_cann_create_tensor( -// bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS -// ); - -// int64_t repeats[] = {1, src0->ne[2], 1, 1}; -// aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); -// PrintOutResultShort(bcast_pse_ne, &(src3->data), "mask"); -// }else{ -// // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); -// int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; -// size_t* trunc_pse_nb = src3->nb; -// aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( -// src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS -// ); -// ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); -// void* bcast_pse_buffer = -// bcast_pse_allocator.alloc(src3->ne[0] * src0->ne[1] * src0->ne[2] * src3->ne[3] * sizeof(uint16_t)); - -// int64_t bcast_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src0->ne[2], src3->ne[3]}; -// size_t bcast_pse_nb[GGML_MAX_DIMS]; -// bcast_pse_nb[0] = sizeof(uint16_t); - -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; -// } - -// std::cout << bcast_pse_ne[0] << " " << bcast_pse_ne[1] << " " << bcast_pse_ne[2] << " " << bcast_pse_ne[3] << std::endl; -// std::cout << bcast_pse_nb[0] << " " << bcast_pse_nb[1] << " " << bcast_pse_nb[2] << " " << bcast_pse_nb[3] << std::endl; - - -// bcast_pse_tensor = ggml_cann_create_tensor( -// bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS -// ); -// int64_t repeats[] = {1, src0->ne[2], 1, 1}; -// aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); -// PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat"); - -// // aclnn_muls(ctx, bcast_pse_tensor, sqrt(src0->ne[0]), nullptr, true); -// } -// } - -// // ggml_cann_release_resources(ctx, acl_mask_f16_tensor); - -// // Step 3: generates the output tensor directly from FA kernel -// ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); -// out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); -// void* out_f16_buffer = out_f16_allocator.get(); - -// int64_t* out_f16_ne = src0->ne; -// size_t out_f16_nb[GGML_MAX_DIMS]; -// out_f16_nb[0] = faElemSize; -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; -// } - -// aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor( -// out_f16_buffer, faDataType, faElemSize, -// out_f16_ne, out_f16_nb, GGML_MAX_DIMS -// ); - -// // Step 4: Performs the f16 Flash 
Attention kernel - -// int kvTensorNum = 1; -// aclTensor* acl_q_tensor = acl_src0_f16_tensor; - -// auto tmp_k = ggml_cann_create_tensor(src1); -// ggml_cann_pool_alloc k_f32_allocator(ctx.pool()); -// k_f32_allocator.alloc(ggml_nelements(src1) * sizeof(float)); -// void* k_f32_buffer = k_f32_allocator.get(); -// size_t k_f32_nb[GGML_MAX_DIMS]; -// for(int i = 0; i < GGML_MAX_DIMS; ++i){ -// k_f32_nb[i] = src1->nb[i] * 2; -// } - -// auto tmp_k_f32 = ggml_cann_create_tensor(k_f32_buffer, ACL_FLOAT, sizeof(float), -// src1->ne, k_f32_nb, GGML_MAX_DIMS); -// aclnn_cast(ctx, tmp_k, tmp_k_f32, ACL_FLOAT); - -// auto tmp_v = ggml_cann_create_tensor(src2); -// ggml_cann_pool_alloc v_f32_allocator(ctx.pool()); -// v_f32_allocator.alloc(ggml_nelements(src2) * sizeof(float)); -// void* v_f32_buffer = v_f32_allocator.get(); -// size_t v_f32_nb[GGML_MAX_DIMS]; -// for(int i = 0; i < GGML_MAX_DIMS; ++i){ -// v_f32_nb[i] = src1->nb[i] * 2; -// } -// auto tmp_v_f32 = ggml_cann_create_tensor(v_f32_buffer, ACL_FLOAT, sizeof(float), -// src2->ne, v_f32_nb, GGML_MAX_DIMS); -// aclnn_cast(ctx, tmp_v, tmp_v_f32, ACL_FLOAT); - -// PrintOutResultFloat(src1->ne, &k_f32_buffer, "k"); -// PrintOutResultFloat(src2->ne, &v_f32_buffer, "v"); - - -// aclTensor* acl_k_tensors[] = {tmp_k}; -// aclTensor* acl_v_tensors[] = {tmp_v}; - -// // aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)}; -// // aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)}; -// auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); -// auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); -// aclTensor* acl_out_tensor = acl_out_f16_tensor; - - -// int64_t numHeads = src0->ne[2]; // N -// int64_t numKeyValueHeads = src1->ne[2]; -// double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) -// int64_t preTokens = 65535; -// int64_t nextTokens = 65535; -// char layout[5] = {'B', 'N', 'S', 'D', 0}; -// int64_t sparseMode = 0; -// int64_t innerPrecise = 2; -// int64_t blockSize = 0; -// int64_t antiquantMode = 0; -// bool softmaxLseFlag = false; -// int64_t keyAntiquantMode = 0; -// int64_t valueAntiquantMode = 0; - -// // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md - -// GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, -// acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v -// bcast_pse_tensor, nullptr, // pse, mask -// nullptr, nullptr, // actSeqLen, actSeqLenkv -// nullptr, nullptr, // deqScale1, quantScale1 -// nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 -// nullptr, nullptr, // antiquantScale, antiquantOffset -// nullptr, // blockTable -// nullptr, nullptr, // qPadSize, kvPadSize -// nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset -// nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset -// nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen -// numHeads, scaleValue, // heads, scaleValue -// preTokens, nextTokens, // preTokens, nextTokens -// layout, // inputLayout -// numKeyValueHeads, // numKVHeads -// sparseMode, innerPrecise, // sparseMode, innerPrecise -// blockSize, antiquantMode, // blockSize, antiquantMode -// softmaxLseFlag, // softmaxLseFlag -// keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode -// acl_out_tensor, // attentionOut -// nullptr // softmaxLse -// ); - -// ggml_cann_pool_alloc aux_out_f32_allocator(ctx.pool()); -// aux_out_f32_allocator.alloc(ggml_nelements(dst) * sizeof(float)); -// void* aux_out_f32_buffer = 
aux_out_f32_allocator.get(); -// int64_t* aux_out_f32_ne = out_f16_ne; -// size_t aux_out_f32_nb[GGML_MAX_DIMS]; -// for(int i = 0; i < GGML_MAX_DIMS; ++i){ -// aux_out_f32_nb[i] = out_f16_nb[i] * 2; -// } -// aclTensor* aux_out_f32_tensor = ggml_cann_create_tensor( -// aux_out_f32_buffer, ACL_FLOAT, sizeof(float), aux_out_f32_ne, aux_out_f32_nb, GGML_MAX_DIMS -// ); - -// aclnn_cast(ctx, -// acl_out_tensor, aux_out_f32_tensor, ggml_cann_type_mapping(dst->type)); - - - -// PrintOutResultFloat(out_f16_ne, &aux_out_f32_buffer, "fia-res-f32"); - -// // Step 5: post-processing, permute and cast to f32 -// int64_t new_dim[] = {0, 2, 1, 3}; -// aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - -// if(ggml_cann_type_mapping(dst->type) != faDataType){ -// ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); -// perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); -// void* perm_out_f16_buffer = perm_out_f16_allocator.get(); - -// int64_t* perm_out_f16_ne = dst->ne; -// size_t perm_out_f16_nb[GGML_MAX_DIMS]; -// perm_out_f16_nb[0] = faElemSize; -// for(int i = 1; i < GGML_MAX_DIMS; ++i){ -// perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; -// } -// aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( -// perm_out_f16_buffer, faDataType, faElemSize, -// perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); -// // aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( -// // perm_out_f16_buffer, ACL_FLOAT16, faElemSize, -// // perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); -// aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); -// aclnn_cast(ctx, -// acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); -// ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); -// }else{ -// // only need to permute -// aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); -// } - -// ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, -// acl_mask_bool_tensor, acl_out_f16_tensor, -// acl_dst_tensor); - -// } \ No newline at end of file +} \ No newline at end of file From f5e24a5c7d8ff739a0f9fa586f9a5f4aabeb328e Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Mon, 19 May 2025 00:10:28 +0800 Subject: [PATCH 05/16] cann: update the alibi with max_bias --- ggml/src/ggml-cann/aclnn_ops.cpp | 108 +++++++++++++++++++++++++++++-- 1 file changed, 103 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 0e55de00d5b05..7c18dd5748425 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2690,10 +2690,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); - // if(logitSoftcap != 0.0f){ - // // call the non-fa implementation - // }else{ - size_t faElemSize = sizeof(uint16_t); auto faDataType = ACL_FLOAT16; //ACL_BF16; @@ -2825,6 +2821,108 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ #endif ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); } + + if(maxBias != 0.0f){ + // alibi + const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; + const int64_t n_head = src0->ne[2]; + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); 
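+            // ALiBi slopes, indexed by h over ne[2]*ne[3] (heads x batch):
+            //   slope(h) = m0^(h + 1)                            for h <  n_heads_log2_floor
+            //   slope(h) = m1^(2*(h - n_heads_log2_floor) + 1)   for h >= n_heads_log2_floor
+            // The arange/fill/pow sequence below materializes these slopes on device,
+            // then multiplies them into the broadcast mask (bcast_pse_tensor) per head/batch.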
+ // init arange + ggml_cann_pool_alloc arange_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_arange_buffer = arange_allocator.get(); + + // arange1: [1, ..., n_heads_log2_floor+1) + float start = 1; + float stop = n_heads_log2_floor + 1; + float step = 1; + int64_t n_elements_arange = n_heads_log2_floor; + + int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + size_t tmp_arange1_nb[] = {faElemSize}; + aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_arange1_ne, tmp_arange1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); + + aclTensor* tmp_arange2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + start = 1; + stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + step = 2; + n_elements_arange = ne2_ne3 - n_heads_log2_floor; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_arange2_nb[] = {faElemSize}; + + aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( + (char*)tmp_arange_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, + n_elements_arange); + } + + // init mk_base + ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_mk_base_buffer = mk_base_allocator.get(); + int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + size_t tmp_mk_base1_nb[] = {faElemSize}; + aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base1_ne, tmp_mk_base1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); + + aclTensor* tmp_mk_base2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_mk_base2_nb[] = {faElemSize}; + aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( + (char*)tmp_mk_base_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + } + + // init mk + int64_t tmp_mk_base_ne[] = {ne2_ne3}; + size_t tmp_mk_base_nb[] = {faElemSize}; + aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); + + // reshape mk + int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; + size_t tmp_mk_nb[GGML_MAX_DIMS]; + tmp_mk_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + } + aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, + ACL_FORMAT_ND); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); + + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor); + } } #ifdef DEBUG @@ -2931,4 +3029,4 @@ void 
ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor); if(src3) ggml_cann_release_resources(ctx, bcast_pse_tensor); -} \ No newline at end of file +} From c8c2908bfc9cadc2abe0973521dedecd657dcf63 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Mon, 19 May 2025 13:12:29 +0800 Subject: [PATCH 06/16] cann: add the constrints of softcap --- ggml/src/ggml-cann/aclnn_ops.cpp | 669 +++++++++++++------------------ ggml/src/ggml-cann/ifa.py | 43 -- 2 files changed, 278 insertions(+), 434 deletions(-) delete mode 100644 ggml/src/ggml-cann/ifa.py diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 7c18dd5748425..f6d80cb03e519 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -82,11 +82,13 @@ #include "ggml-cann/acl_tensor.h" #include "ggml-impl.h" +#include "ggml.h" #define GGML_COMMON_DECL_C #include "../ggml-common.h" + void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); @@ -2597,85 +2599,6 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); } -#ifdef DEBUG -static int cnt = 0; - -static void PrintOutResultShort(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { - auto size = ne[0] * ne[1] * ne[2] * ne[3]; - std::vector resultData(size, 0); - auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), - *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); - // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); - - // 打开文件 - std::string filename = "output_acl_short_" + std::to_string(cnt) + "_" + s + ".txt"; - cnt++; - std::ofstream outFile(filename); - - // 将数据写入文件 - for(size_t i = 0; i < size; ++i){ - outFile << GGML_FP16_TO_FP32(resultData[i]) << " "; - if(i > 0 && i % ne[0] == 0){ - outFile << "\n"; - } - } - outFile << std::endl << std::endl; - // 关闭文件 - outFile.close(); -} - - -static void PrintOutResultChar(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { - auto size = ne[0] * ne[1] * ne[2] * ne[3]; - std::vector resultData(size, 0); - auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), - *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); - // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); - - // 打开文件 - std::string filename = "output_acl_char_" + std::to_string(cnt) + "_" + s + ".txt"; - cnt++; - std::ofstream outFile(filename); - - // 将数据写入文件 - for(size_t i = 0; i < size; ++i){ - outFile << int(resultData[i]) << " "; - if(i > 0 && i % ne[0] == 0){ - outFile << "\n"; - } - } - outFile << std::endl << std::endl; - // 关闭文件 - outFile.close(); -} - -static void PrintOutResultFloat(int64_t ne[GGML_MAX_DIMS], void** deviceAddr, std::string s) { - auto size = ne[0] * ne[1] * ne[2] * ne[3]; - std::vector resultData(size, 0); - auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), - *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); - // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. 
ERROR: %d\n", ret); return); - - // 打开文件 - std::string filename = "output_acl_short_" + std::to_string(cnt) + "_" + s + ".txt"; - cnt++; - std::ofstream outFile(filename); - - // 将数据写入文件 - for(size_t i = 0; i < size; ++i){ - outFile << float(resultData[i]) << " "; - if(i > 0 && i % ne[0] == 0){ - outFile << "\n"; - } - } - outFile << std::endl << std::endl; - // 关闭文件 - outFile.close(); -} - - -#endif - void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_tensor* src0 = dst->src[0]; // q, fp32 @@ -2683,350 +2606,314 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_tensor* src2 = dst->src[2]; // v, fp16 ggml_tensor* src3 = dst->src[3]; // mask, fp16 - float maxBias = 0.0; - float scaleValue = 1.0; - float logitSoftcap = 0.0; + float maxBias = 0.0f; + float scaleValue = 1.0f; + float logitSoftcap = 0.0f; memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); - - size_t faElemSize = sizeof(uint16_t); - auto faDataType = ACL_FLOAT16; //ACL_BF16; - - aclTensor* acl_src0_f16_tensor = nullptr; - aclTensor* acl_src1_f16_tensor = nullptr; - aclTensor* acl_src2_f16_tensor = nullptr; - aclTensor* acl_src3_f16_tensor = nullptr; - aclTensor* acl_dst_f16_tensor = nullptr; - - // Step 1: cast the src0 (Query) to fp16 - ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); - void* src0_f16_buffer = nullptr; - - if(ggml_cann_type_mapping(src0->type) != faDataType){ - aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); - src0_f16_buffer = src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); - - int64_t* src0_f16_ne = src0->ne; - size_t src0_f16_nb[GGML_MAX_DIMS]; - src0_f16_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; - } - acl_src0_f16_tensor = ggml_cann_create_tensor( - src0_f16_buffer, faDataType, faElemSize, - src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS - ); - aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); - ggml_cann_release_resources(ctx, acl_src0_f32_tensor); - }else{ - acl_src0_f16_tensor = ggml_cann_create_tensor(src0); - } + if(logitSoftcap == 0.0f){ + size_t faElemSize = sizeof(uint16_t); + auto faDataType = ACL_FLOAT16; //ACL_BF16; - acl_src1_f16_tensor = ggml_cann_create_tensor(src1); - acl_src2_f16_tensor = ggml_cann_create_tensor(src2); - -#ifdef DEBUG - PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-1"); - PrintOutResultShort(src1->ne, &(src1->data), "k-1"); - PrintOutResultShort(src2->ne, &(src2->data), "v-1"); -#endif + aclTensor* acl_src0_f16_tensor = nullptr; + aclTensor* acl_src1_f16_tensor = nullptr; + aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor* acl_src3_f16_tensor = nullptr; + aclTensor* acl_dst_f16_tensor = nullptr; - ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - void* out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + // Step 1: cast the src0 (Query) to fp16 + ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); + void* src0_f16_buffer = nullptr; - int64_t* out_f16_ne = src0->ne; - size_t out_f16_nb[GGML_MAX_DIMS]; - out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; - } + if(ggml_cann_type_mapping(src0->type) != faDataType){ + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + src0_f16_buffer = 
src0_f16_allocator.alloc( + ggml_nelements(src0) * faElemSize); - acl_dst_f16_tensor = ggml_cann_create_tensor( - out_f16_buffer, faDataType, faElemSize, - out_f16_ne, out_f16_nb, GGML_MAX_DIMS - ); + int64_t* src0_f16_ne = src0->ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; + src0_f16_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; + } -#ifdef DEBUG - PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-2"); - PrintOutResultShort(src1->ne, &(src1->data), "k-2"); - PrintOutResultShort(src2->ne, &(src2->data), "v-2"); - PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-2"); -#endif + acl_src0_f16_tensor = ggml_cann_create_tensor( + src0_f16_buffer, faDataType, faElemSize, + src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS + ); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); + ggml_cann_release_resources(ctx, acl_src0_f32_tensor); + }else{ + acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + } - aclTensor* bcast_pse_tensor = nullptr; + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2); - int64_t bcast_pse_ne[GGML_MAX_DIMS]; - size_t bcast_pse_nb[GGML_MAX_DIMS]; - ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); - void* bcast_pse_buffer = nullptr; - if(src3) - bcast_pse_buffer = - bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); - if(src3 != nullptr){ -#ifdef DEBUG - PrintOutResultShort(src3->ne, &(src3->data), "mask"); -#endif - // broadcast pse - if(src0->ne[1] > 1){ - aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + int64_t* out_f16_ne = src0->ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src3->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + acl_dst_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); - bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; - } + aclTensor* bcast_pse_tensor = nullptr; + + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); + void* bcast_pse_buffer = nullptr; + if(src3) + bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + + if(src3 != nullptr){ + // broadcast pse + if(src0->ne[1] > 1){ + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src3->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } - bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS - ); + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); - int64_t repeats[] = {1, src0->ne[2], 1, 1}; - aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); 
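llama.cpp supplies one fp16 mask shared by every attention head, while FusedInferAttentionScoreV2 consumes it through the per-head pseShift input (which the ALiBi slopes later scale head by head), so the repeat above materializes the mask once per head. A shape sketch of the prefill case, with hypothetical sizes and a purely illustrative helper:

#include <cstddef>
#include <cstdint>

// ggml stores dims innermost-first: ne[0] = n_kv, ne[1] = n_q (padded),
// ne[2] = heads, ne[3] = batch. The broadcast keeps dims 0/1/3 and repeats dim 2.
static size_t pse_shift_bytes(const int64_t mask_ne[4], int64_t n_head) {
    // mask { n_kv, n_q_pad, 1, 1 }  ->  pse { n_kv, n_q_pad, n_head, 1 }
    return (size_t) (mask_ne[0] * mask_ne[1] * n_head * mask_ne[3]) * sizeof(uint16_t); // fp16
}
// e.g. n_kv = 512, n_q_pad = 32, n_head = 32 -> 512*32*32 fp16 elements = 1 MiB.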
-#ifdef DEBUG - PrintOutResultShort(bcast_pse_ne, &(src3->data), "repeat-1"); -#endif - ggml_cann_release_resources(ctx, acl_mask_f16_tensor); - }else{ - int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; - size_t* trunc_pse_nb = src3->nb; + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); - aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( - src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS - ); + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + }else{ + int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; + size_t* trunc_pse_nb = src3->nb; - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src0->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); - bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; - } + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src0->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; - bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS - ); - int64_t repeats[] = {1, src0->ne[2], 1, 1}; - aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); -#ifdef DEBUG - PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-1"); -#endif - ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); - } + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } - if(maxBias != 0.0f){ - // alibi - const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; - const int64_t n_head = src0->ne[2]; - const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); - float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); - // init arange - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - ne2_ne3 * faElemSize); - void* tmp_arange_buffer = arange_allocator.get(); - - // arange1: [1, ..., n_heads_log2_floor+1) - float start = 1; - float stop = n_heads_log2_floor + 1; - float step = 1; - int64_t n_elements_arange = n_heads_log2_floor; - - int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; - size_t tmp_arange1_nb[] = {faElemSize}; - aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, faDataType, faElemSize, - tmp_arange1_ne, tmp_arange1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); - - aclTensor* tmp_arange2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) - start = 1; - stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; - step = 2; - n_elements_arange = ne2_ne3 - n_heads_log2_floor; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_arange2_nb[] = {faElemSize}; - - aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( - (char*)tmp_arange_buffer + - n_heads_log2_floor * faElemSize, - faDataType, faElemSize, - tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, - n_elements_arange); - } + 
bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); - // init mk_base - ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), - ne2_ne3 * faElemSize); - void* tmp_mk_base_buffer = mk_base_allocator.get(); - int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; - size_t tmp_mk_base1_nb[] = {faElemSize}; - aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, - tmp_mk_base1_ne, tmp_mk_base1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); - - aclTensor* tmp_mk_base2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_mk_base2_nb[] = {faElemSize}; - aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( - (char*)tmp_mk_base_buffer + - n_heads_log2_floor * faElemSize, - faDataType, faElemSize, - tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); - } + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); - // init mk - int64_t tmp_mk_base_ne[] = {ne2_ne3}; - size_t tmp_mk_base_nb[] = {faElemSize}; - aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, - tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, faDataType, faElemSize, - tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); - - // reshape mk - int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; - size_t tmp_mk_nb[GGML_MAX_DIMS]; - tmp_mk_nb[0] = faElemSize; - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); } - aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, - tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); - - ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, - tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor); - } - } - -#ifdef DEBUG - PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-3"); - PrintOutResultShort(src1->ne, &(src1->data), "k-3"); - PrintOutResultShort(src2->ne, &(src2->data), "v-3"); - PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-3"); -#endif + if(maxBias != 0.0f){ + // alibi + const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; + const int64_t n_head = src0->ne[2]; + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + // init arange + ggml_cann_pool_alloc arange_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_arange_buffer = arange_allocator.get(); + + // arange1: [1, ..., n_heads_log2_floor+1) + float start = 1; + float stop = n_heads_log2_floor + 1; + float step = 1; + int64_t n_elements_arange = n_heads_log2_floor; + + int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + size_t tmp_arange1_nb[] = {faElemSize}; + aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_arange1_ne, 
tmp_arange1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); + + aclTensor* tmp_arange2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + start = 1; + stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + step = 2; + n_elements_arange = ne2_ne3 - n_heads_log2_floor; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_arange2_nb[] = {faElemSize}; + + aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( + (char*)tmp_arange_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, + n_elements_arange); + } + // init mk_base + ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_mk_base_buffer = mk_base_allocator.get(); + int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + size_t tmp_mk_base1_nb[] = {faElemSize}; + aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base1_ne, tmp_mk_base1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); + + aclTensor* tmp_mk_base2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_mk_base2_nb[] = {faElemSize}; + aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( + (char*)tmp_mk_base_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + } - int kvTensorNum = 1; - aclTensor* acl_q_tensor = acl_src0_f16_tensor; - aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; - aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; - auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); - auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); - - int64_t numHeads = src0->ne[2]; // N - int64_t numKeyValueHeads = src1->ne[2]; - // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) - int64_t preTokens = 65535; - int64_t nextTokens = 65535; - char layout[5] = {'B', 'N', 'S', 'D', 0}; - int64_t sparseMode = 0; - int64_t innerPrecise = 2; - int64_t blockSize = 0; - int64_t antiquantMode = 0; - bool softmaxLseFlag = false; - int64_t keyAntiquantMode = 0; - int64_t valueAntiquantMode = 0; - - - // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md - -#ifdef DEBUG - PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-4"); - PrintOutResultShort(src1->ne, &(src1->data), "k-4"); - PrintOutResultShort(src2->ne, &(src2->data), "v-4"); - PrintOutResultShort(src0->ne, &(out_f16_buffer), "out-4"); - if(src3) - PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-4"); -#endif + // init mk + int64_t tmp_mk_base_ne[] = {ne2_ne3}; + size_t tmp_mk_base_nb[] = {faElemSize}; + aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); + + // reshape mk + int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], 
src0->ne[3]}; + size_t tmp_mk_nb[GGML_MAX_DIMS]; + tmp_mk_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + } + aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, + ACL_FORMAT_ND); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); + + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor); + } + } - GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, - acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v - bcast_pse_tensor, nullptr, // pse, mask - nullptr, nullptr, // actSeqLen, actSeqLenkv - nullptr, nullptr, // deqScale1, quantScale1 - nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 - nullptr, nullptr, // antiquantScale, antiquantOffset - nullptr, // blockTable - nullptr, nullptr, // qPadSize, kvPadSize - nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset - nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset - nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen - numHeads, scaleValue, // heads, scaleValue - preTokens, nextTokens, // preTokens, nextTokens - layout, // inputLayout - numKeyValueHeads, // numKVHeads - sparseMode, innerPrecise, // sparseMode, innerPrecise - blockSize, antiquantMode, // blockSize, antiquantMode - softmaxLseFlag, // softmaxLseFlag - keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode - acl_dst_f16_tensor, // attentionOut - nullptr // softmaxLse - ); - -#ifdef DEBUG - PrintOutResultShort(src0->ne, &(src0_f16_buffer), "q-5"); - PrintOutResultShort(src1->ne, &(src1->data), "k-5"); - PrintOutResultShort(src2->ne, &(src2->data), "v-5"); - if(src3) - PrintOutResultShort(bcast_pse_ne, &bcast_pse_buffer, "repeat-5"); - PrintOutResultShort(out_f16_ne, &out_f16_buffer, "out-5"); -#endif + int kvTensorNum = 1; + aclTensor* acl_q_tensor = acl_src0_f16_tensor; + aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; + aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; + auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; + // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = {'B', 'N', 'S', 'D', 0}; + int64_t sparseMode = 0; + int64_t innerPrecise = 2; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; + int64_t valueAntiquantMode = 0; + + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + + + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, + acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v + bcast_pse_tensor, nullptr, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, 
scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + acl_dst_f16_tensor, // attentionOut + nullptr // softmaxLse + ); - // Step 5: post-processing, permute and cast to f32 - int64_t new_dim[] = {0, 2, 1, 3}; - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + // Step 5: post-processing, permute and cast to f32 + int64_t new_dim[] = {0, 2, 1, 3}; + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - if(ggml_cann_type_mapping(dst->type) != faDataType){ - ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); - perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + if(ggml_cann_type_mapping(dst->type) != faDataType){ + ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); + perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* perm_out_f16_buffer = perm_out_f16_allocator.get(); - int64_t* perm_out_f16_ne = dst->ne; - size_t perm_out_f16_nb[GGML_MAX_DIMS]; - perm_out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + int64_t* perm_out_f16_ne = dst->ne; + size_t perm_out_f16_nb[GGML_MAX_DIMS]; + perm_out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + } + aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( + perm_out_f16_buffer, faDataType, faElemSize, + perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_cast(ctx, + acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); + ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + }else{ + // only need to permute + aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); } - aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( - perm_out_f16_buffer, faDataType, faElemSize, - perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); - aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); - aclnn_cast(ctx, - acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + acl_dst_f16_tensor, + acl_dst_tensor); + if(src3) + ggml_cann_release_resources(ctx, bcast_pse_tensor); }else{ - // only need to permute - aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + throw std::runtime_error("Function not implemented"); } - ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor); - if(src3) - ggml_cann_release_resources(ctx, bcast_pse_tensor); } diff --git a/ggml/src/ggml-cann/ifa.py b/ggml/src/ggml-cann/ifa.py deleted file mode 100644 index b01f497ef7994..0000000000000 --- a/ggml/src/ggml-cann/ifa.py +++ /dev/null @@ -1,43 +0,0 @@ -# 单算子调用方式 -import torch -import torch_npu -import math - -def load_float_array_to_tensor(file_path, shape, dtype): - with open(file_path, 'r') as file: - # 读取文件内容并按空格分割 
- data = file.read().strip().split() - # 将字符串转换为浮点数 - float_array = [float(num) for num in data] - # 转换为 PyTorch 张量 - tensor = torch.tensor(float_array, dtype=dtype).reshape(shape).npu() - return tensor - -batch = 1 -nhead_q = 4 -nhead_kv = nhead_q -seq_q = 1 -dims = 64 -seq_kv = 512 -layout="BNSD" - -scale_value = 1 / pow(dims, 0.5) - -q_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_0_q.txt", - (batch, nhead_q, seq_q, dims), torch.float16) -k_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_3_k.txt", - (batch, nhead_kv, seq_kv, dims), torch.float16) - -v_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_4_v.txt", - (batch, nhead_kv, seq_kv, dims), torch.float16) - -pse_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_1_mask.txt", - (1, 1, -1, seq_kv), torch.float16) - -print(q_tensor.shape, k_tensor.shape, v_tensor.shape, pse_tensor.shape) - -# 调用IFA算子 -out = torch_npu.npu_incre_flash_attention(q_tensor, k_tensor, v_tensor, pse_shift=pse_tensor, - num_heads=nhead_q, num_key_value_heads=nhead_kv, - input_layout=layout, scale_value=scale_value) - From 47f2c64658ac3065b9ec55a883f67bafd055ebaf Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Mon, 19 May 2025 13:42:33 +0800 Subject: [PATCH 07/16] cann: update the docs CANN.md --- docs/backend/CANN.md | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) mode change 100644 => 100755 docs/backend/CANN.md diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md old mode 100644 new mode 100755 index 9bd2a9127eee6..ec19f1146549a --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -259,10 +259,6 @@ cmake --build build --config release Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. ### Basic Flash Attention Support -The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. - -Authors: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). - - -## TODO -- Support more models and data types. +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. +Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. +Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the \ No newline at end of file From fb62f0158f7a3bd3128832b4e00d6ff4f33b7bfa Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Mon, 19 May 2025 13:47:44 +0800 Subject: [PATCH 08/16] cann: update the docs CANN.md --- docs/backend/CANN.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index ec19f1146549a..f9bf435209596 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -258,7 +258,15 @@ cmake --build build --config release ### **GitHub contribution**: Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. +## Updates ### Basic Flash Attention Support The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. 
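For context on that constraint: when logit_softcap is non-zero, ggml squashes each scaled attention logit with a tanh-based cap before the softmax, and FusedInferAttentionScoreV2 exposes no hook for this transform, so such graphs must stay on the non-FA path. A scalar sketch of the transform (not a CANN API; the helper name is illustrative):

#include <cmath>

// Soft-capping applied to a scaled attention logit s with cap c (> 0):
// bounded to (-c, c) and approximately the identity for |s| << c.
static float logit_softcap_ref(float s, float c) {
    return c * tanhf(s / c);
}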
-Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the \ No newline at end of file +Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future. + +Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). + +Thanks Tuo Dai and Shanni Li from Huawei Technologies Co., Ltd. + +## TODO +- Support more models and d \ No newline at end of file From b266beb203e75995145d42ef46ffd558d965b083 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Wed, 21 May 2025 15:35:11 +0800 Subject: [PATCH 09/16] cann: fix typo of CANN.md --- docs/backend/CANN.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index f9bf435209596..5d655b0605974 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -269,4 +269,4 @@ Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang Thanks Tuo Dai and Shanni Li from Huawei Technologies Co., Ltd. ## TODO -- Support more models and d \ No newline at end of file +- Support more models and data types. From 8a112f0a2ba04ebbccc66e24e8a470a109b1cf78 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Wed, 21 May 2025 16:04:39 +0800 Subject: [PATCH 10/16] cann: add some comments and update the CANN.md --- ggml/src/ggml-cann/aclnn_ops.cpp | 46 +++++++++++++++++--------------- ggml/src/ggml-cann/aclnn_ops.h | 2 -- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f6d80cb03e519..bbc636fe0b2f7 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include @@ -72,17 +73,7 @@ #include #include -#include -#include -#include -#include - -#include "aclnnop/aclnn_flash_attention_score.h" -#include "aclnnop/aclnn_logical_not.h" - -#include "ggml-cann/acl_tensor.h" #include "ggml-impl.h" -#include "ggml.h" #define GGML_COMMON_DECL_C @@ -2623,7 +2614,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_src3_f16_tensor = nullptr; aclTensor* acl_dst_f16_tensor = nullptr; - // Step 1: cast the src0 (Query) to fp16 + // Step 1: cast the src0 (Query) to fp16 if needed ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); void* src0_f16_buffer = nullptr; @@ -2649,6 +2640,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ acl_src0_f16_tensor = ggml_cann_create_tensor(src0); } + // Step 2: create the acl tensors for src1 (Key), src2 (Value), + // and the direct output from FusedInferAttention + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); acl_src2_f16_tensor = ggml_cann_create_tensor(src2); @@ -2668,21 +2662,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ out_f16_ne, out_f16_nb, GGML_MAX_DIMS ); - aclTensor* bcast_pse_tensor = nullptr; + // Step 3: create the PSEShift tensor if needed + // this tensor is considered as mask (f16) in the llama.cpp + + aclTensor* bcast_pse_tensor = nullptr; int64_t bcast_pse_ne[GGML_MAX_DIMS]; size_t bcast_pse_nb[GGML_MAX_DIMS]; ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); void* bcast_pse_buffer = nullptr; - if(src3) + + if(src3 != nullptr){ bcast_pse_buffer = bcast_pse_allocator.alloc( ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - - 
if(src3 != nullptr){ - // broadcast pse + if(src0->ne[1] > 1){ + // Case 1: broadcast pse for prefill stage with multiple head aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); - bcast_pse_ne[0] = src3->ne[0]; bcast_pse_ne[1] = src3->ne[1]; bcast_pse_ne[2] = src0->ne[2]; @@ -2702,6 +2698,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_mask_f16_tensor); }else{ + // Case 2: trunc the first row and broadcast pse for decode stage with multiple head int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; size_t* trunc_pse_nb = src3->nb; @@ -2729,6 +2726,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); } + // Compute the slope if needed. Derived from ggml_cann_softmax(). if(maxBias != 0.0f){ // alibi const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; @@ -2832,6 +2830,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } } + // Step 4: set the inputs for FusedInferAttention. int kvTensorNum = 1; aclTensor* acl_q_tensor = acl_src0_f16_tensor; aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; @@ -2853,9 +2852,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t keyAntiquantMode = 0; int64_t valueAntiquantMode = 0; + // Step 5: launch the FusedInferAttentionScoreV2 kernel. // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md - - + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v bcast_pse_tensor, nullptr, // pse, mask @@ -2880,7 +2879,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ nullptr // softmaxLse ); - // Step 5: post-processing, permute and cast to f32 + // Step 6: post-processing, permute and cast to f32 + int64_t new_dim[] = {0, 2, 1, 3}; aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); @@ -2911,9 +2911,11 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor); - if(src3) + if(src3 != nullptr){ ggml_cann_release_resources(ctx, bcast_pse_tensor); + } }else{ - throw std::runtime_error("Function not implemented"); + GGML_ABORT("Function not implemented"); } } + \ No newline at end of file diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a4fedc29cb680..4de24d0016950 100755 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -45,8 +45,6 @@ #include #include #include -#include -#include #include "acl_tensor.h" #include "common.h" From 1779e0085578efb5d9349e15bd7687bedce9e1cc Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Wed, 21 May 2025 16:09:28 +0800 Subject: [PATCH 11/16] cann: update the CANN.md --- docs/backend/CANN.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 5d655b0605974..5fda3aba9027f 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -266,7 +266,7 @@ Since the aclnn interface for flash attention cannot support the logit softcap, Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). -Thanks Tuo Dai and Shanni Li from Huawei Technologies Co., Ltd. 
+We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request. ## TODO - Support more models and data types. From 092ccf68438a07d597d45cf8b9aa3152e1055ed1 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Wed, 21 May 2025 18:15:05 +0800 Subject: [PATCH 12/16] cann: update the inner precise for fusedInferAttention --- ggml/src/ggml-cann/aclnn_ops.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index bbc636fe0b2f7..722e3c7c58f36 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -74,6 +74,7 @@ #include #include "ggml-impl.h" +#include "ggml.h" #define GGML_COMMON_DECL_C @@ -2611,7 +2612,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_src0_f16_tensor = nullptr; aclTensor* acl_src1_f16_tensor = nullptr; aclTensor* acl_src2_f16_tensor = nullptr; - aclTensor* acl_src3_f16_tensor = nullptr; aclTensor* acl_dst_f16_tensor = nullptr; // Step 1: cast the src0 (Query) to fp16 if needed @@ -2845,7 +2845,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t nextTokens = 65535; char layout[5] = {'B', 'N', 'S', 'D', 0}; int64_t sparseMode = 0; - int64_t innerPrecise = 2; + int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; int64_t blockSize = 0; int64_t antiquantMode = 0; bool softmaxLseFlag = false; @@ -2915,7 +2915,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, bcast_pse_tensor); } }else{ - GGML_ABORT("Function not implemented"); + GGML_ABORT("Function is not implemented."); } -} - \ No newline at end of file +} \ No newline at end of file From c380305b294eef5daea4f677e5a12fb117b15dfd Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Thu, 22 May 2025 10:37:59 +0800 Subject: [PATCH 13/16] cann: update the constraints of flash_attn_ext on ggml-cann.cpp --- ggml/src/ggml-cann/ggml-cann.cpp | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index f4fd563556c9b..f50622917be3d 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,6 +36,7 @@ #include "ggml-backend-impl.h" #include "ggml-cann/aclnn_ops.h" #include "ggml-cann/common.h" +#include "ggml.h" #define GGML_COMMON_DECL_C @@ -2165,7 +2166,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_COUNT_EQUAL: return true; case GGML_OP_FLASH_ATTN_EXT:{ - // copy from [ggml-cuda.cu] + // derived from [ggml-cuda.cu] + if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ + return false; + } + if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){ + return false; + } + if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ + return false; + } if (op->src[1]->ne[0] != op->src[2]->ne[0]) { // different head sizes of K and V are not supported yet return false; @@ -2180,19 +2190,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (op->src[0]->ne[3] != 1) { return false; } - if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { + float logitSoftcap = 0.0f; + memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); + 
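The three floats read above are the parameters ggml_flash_attn_ext packs into op_params in order: scale, max_bias (ALiBi) and logit_softcap; only the last one gates the CANN path in the check that follows. A minimal sketch of that gate as a standalone helper (helper name illustrative; the field order is taken from the reads in this hunk):

#include <cstring>
#include "ggml.h"

// op_params layout for GGML_OP_FLASH_ATTN_EXT: { scale, max_bias, logit_softcap }.
static bool cann_fa_softcap_ok(const ggml_tensor * op) {
    float logit_softcap = 0.0f;
    memcpy(&logit_softcap, (const float *) op->op_params + 2, sizeof(float));
    return logit_softcap == 0.0f; // the fused kernel has no soft-cap hook
}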
if(logitSoftcap != 0.0f) { return false; } - if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { - return true; - } - if (op->src[0]->ne[0] == 128) { - return true; - } - if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { - return true; - } - return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + return true; } default: return false; From 3b084d5b3e75e414381f9986a9de095c45faa6f4 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Fri, 23 May 2025 17:41:21 +0800 Subject: [PATCH 14/16] cann: clean the whitespace --- ggml/src/ggml-cann/aclnn_ops.cpp | 47 ++++++++++++++++---------------- ggml/src/ggml-cann/aclnn_ops.h | 2 +- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 247e04ef91be8..dd89920649256 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2898,14 +2898,14 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_nelements(src0) * faElemSize); int64_t* src0_f16_ne = src0->ne; - size_t src0_f16_nb[GGML_MAX_DIMS]; + size_t src0_f16_nb[GGML_MAX_DIMS]; src0_f16_nb[0] = sizeof(uint16_t); for(int i = 1; i < GGML_MAX_DIMS; ++i){ src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; } acl_src0_f16_tensor = ggml_cann_create_tensor( - src0_f16_buffer, faDataType, faElemSize, + src0_f16_buffer, faDataType, faElemSize, src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS ); aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); @@ -2914,7 +2914,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ acl_src0_f16_tensor = ggml_cann_create_tensor(src0); } - // Step 2: create the acl tensors for src1 (Key), src2 (Value), + // Step 2: create the acl tensors for src1 (Key), src2 (Value), // and the direct output from FusedInferAttention acl_src1_f16_tensor = ggml_cann_create_tensor(src1); @@ -2932,24 +2932,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } acl_dst_f16_tensor = ggml_cann_create_tensor( - out_f16_buffer, faDataType, faElemSize, + out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS ); - // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp - + aclTensor* bcast_pse_tensor = nullptr; int64_t bcast_pse_ne[GGML_MAX_DIMS]; size_t bcast_pse_nb[GGML_MAX_DIMS]; ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); void* bcast_pse_buffer = nullptr; - + if(src3 != nullptr){ bcast_pse_buffer = bcast_pse_allocator.alloc( ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - + if(src0->ne[1] > 1){ // Case 1: broadcast pse for prefill stage with multiple head aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); @@ -2964,7 +2963,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); int64_t repeats[] = {1, src0->ne[2], 1, 1}; @@ -2977,7 +2976,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ size_t* trunc_pse_nb = src3->nb; aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( - src3->data, ACL_FLOAT16, sizeof(uint16_t), + src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); 
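In this decode-stage branch the mask handed over by llama.cpp is padded along its row dimension, but only the first src0->ne[1] rows (typically one) are meaningful, so the code takes a zero-copy view with a smaller ne[1] over src3's existing data and strides before repeating it per head. A small sketch of the shape logic (names and sizes are illustrative):

#include <cstdint>

// Zero-copy truncation of the padded mask down to the live query rows;
// strides are reused from the mask unchanged, padded rows are simply never read.
static void mask_decode_view_ne(const int64_t mask_ne[4], int64_t n_q, int64_t out_ne[4]) {
    out_ne[0] = mask_ne[0]; // n_kv, e.g. 512
    out_ne[1] = n_q;        // live rows, typically 1 during decode
    out_ne[2] = mask_ne[2];
    out_ne[3] = mask_ne[3];
}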
bcast_pse_ne[0] = src3->ne[0]; @@ -2991,7 +2990,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); int64_t repeats[] = {1, src0->ne[2], 1, 1}; @@ -3007,8 +3006,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ const int64_t n_head = src0->ne[2]; const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); - // init arange + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + // init arange ggml_cann_pool_alloc arange_allocator(ctx.pool(), ne2_ne3 * faElemSize); void* tmp_arange_buffer = arange_allocator.get(); @@ -3076,11 +3075,11 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t tmp_mk_base_ne[] = {ne2_ne3}; size_t tmp_mk_base_nb[] = {faElemSize}; aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_buffer, faDataType, faElemSize, tmp_mk_base_ne, tmp_mk_base_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, faDataType, faElemSize, + tmp_arange_buffer, faDataType, faElemSize, tmp_mk_base_ne, tmp_mk_base_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); @@ -3095,12 +3094,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( tmp_mk_base_buffer, faDataType, faElemSize, tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); + ACL_FORMAT_ND); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor); + tmp_arange_tensor, tmp_mk_tensor); } } @@ -3128,7 +3127,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // Step 5: launch the FusedInferAttentionScoreV2 kernel. 
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md - + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v bcast_pse_tensor, nullptr, // pse, mask @@ -3170,20 +3169,20 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; } aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( - perm_out_f16_buffer, faDataType, faElemSize, + perm_out_f16_buffer, faDataType, faElemSize, perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); - aclnn_cast(ctx, + aclnn_cast(ctx, acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); }else{ // only need to permute aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); } - ggml_cann_release_resources(ctx, acl_src0_f16_tensor, - acl_src1_f16_tensor, - acl_src2_f16_tensor, - acl_dst_f16_tensor, + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + acl_dst_f16_tensor, acl_dst_tensor); if(src3 != nullptr){ ggml_cann_release_resources(ctx, bcast_pse_tensor); diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 3873b8aa6cd72..80ce80baea02c 100755 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -720,7 +720,7 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @details This function implements the memory-efficient Flash Attention algorithm * for computing scaled dot-product attention with hardware acceleration. * The result is stored in the destination tensor `dst`. - * + * * This operation is accelerated using the CANN backend to improve runtime performance. * * @param ctx The CANN context used for operations. From d23697b8a01cbb335d483d89e64dcada59081a82 Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Fri, 23 May 2025 17:41:42 +0800 Subject: [PATCH 15/16] cann: clean the whitespace --- docs/backend/CANN.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 52113aa885ea5..a5ba617ca7bab 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -282,8 +282,8 @@ Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team ## Updates ### Basic Flash Attention Support -The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. -Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. +Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future. Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). 
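One layout detail from the post-processing step earlier in this series is worth spelling out: FusedInferAttentionScoreV2 writes its result in the same BNSD ordering as the query, while ggml defines the flash-attention destination with the head and sequence dims already swapped, so the kernel output is permuted with {0, 2, 1, 3} (and cast back to f32 when required) before being written to dst. A dim-order sketch (ggml order, innermost first; symbols only):

// q / kernel out : { d_head, n_q,    n_head, batch }  reads as B,N,S,D outermost-first
// dst            : { d_head, n_head, n_q,    batch }  ggml's pre-permuted FA output
// permute {0,2,1,3} maps the kernel output onto dst's layout before the final cast.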
From 8a7829b7b8f9101f459ba774798bd18d8898fe8e Mon Sep 17 00:00:00 2001 From: Bizhao Shi Date: Fri, 23 May 2025 18:41:31 +0800 Subject: [PATCH 16/16] cann: add a new endline --- ggml/src/ggml-cann/aclnn_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index dd89920649256..437ece2d4a3cf 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3190,4 +3190,4 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ }else{ GGML_ABORT("Function is not implemented."); } -} \ No newline at end of file +}