From 3e6d1e4e2d63dd74e8087b62e493ece4fdd3421a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Apr 2025 13:22:34 +0300
Subject: [PATCH 1/3] ggml : FA supports F32 V

---
 ggml/src/ggml-cpu/ops.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 7a8d5ac6fd9d0..33ba8806fb613 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;

-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
+    GGML_ASSERT( q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v->type == GGML_TYPE_F32 || v_to_float && "fattn: unsupported V-type");

     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }

-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }

             S = S*ms + vs; // scale and increment sum with partial sum

From 7cb9ae059ae495ad82a839d2adce136de0a9a97f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Apr 2025 13:23:14 +0300
Subject: [PATCH 2/3] graph : cast KV to F16 when the KV cache is not used

ggml-ci
---
 examples/server_embd.py | 2 +-
 src/llama-graph.cpp     | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/examples/server_embd.py b/examples/server_embd.py
index 0e34c6ceab9ca..f8b0ffecd8f47 100644
--- a/examples/server_embd.py
+++ b/examples/server_embd.py
@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])

     for response in responses:
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c3469177e091c..cd955d63bc390 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }

+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);


From 997b1b42a455e4dd5e9d75a559a73a4796e0715c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Apr 2025 13:41:26 +0300
Subject: [PATCH 3/3] server : add test that exercises embeddings with FA enabled

ggml-ci
---
 examples/server/tests/unit/test_embedding.py | 20 ++++++++++++++++++++
 examples/server/tests/utils.py               | 15 +++++++++++++++
 ggml/src/ggml-cpu/ops.cpp                    |  4 ++--
 ggml/src/ggml-metal/ggml-metal.m             |  5 +++++
 4 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 8b0eb42b0926f..0feb452ccfcd4 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1


+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 30aa8660950a1..4dc2062a8e5b9 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -323,6 +323,21 @@ def bert_bge_small() -> ServerProcess:
         server.server_embeddings = True
         return server

+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 33ba8806fb613..f63656be54f5c 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;

-    GGML_ASSERT( q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v->type == GGML_TYPE_F32 || v_to_float && "fattn: unsupported V-type");
+    GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");

     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 456e1fd994c40..f226826020a5a 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }