From b4c169fdd5eaf9faaacf3b61a781694ca9254541 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 17:43:53 +0100 Subject: [PATCH 01/23] Initial commit with all but the MLA graph code done --- common/arg.cpp | 7 ++ common/common.cpp | 1 + common/common.h | 1 + convert_hf_to_gguf.py | 22 ++++++ examples/server/README.md | 1 + gguf-py/gguf/constants.py | 6 ++ gguf-py/gguf/tensor_mapping.py | 8 +++ include/llama.h | 1 + src/llama-arch.cpp | 21 ++---- src/llama-arch.h | 2 + src/llama-context.cpp | 8 +++ src/llama-cparams.h | 1 + src/llama-kv-cache.cpp | 23 +++++-- src/llama-model.cpp | 121 +++++++++++++++++++++++++++++++++ src/llama-model.h | 2 + 15 files changed, 201 insertions(+), 24 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index fa22e86cd14e6..238db672dd1ec 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.flash_attn = true; } ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"-mla", "--mla-attn"}, + string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"), + [](common_params & params) { + params.mla_attn = true; + } + ).set_env("LLAMA_ARG_MLA_ATTN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index d4882c5123cce..d3ab321487637 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; if (params.reranking) { diff --git a/common/common.h b/common/common.h index 725b5123d24f9..cd38a646dcfad 100644 --- a/common/common.h +++ b/common/common.h @@ -319,6 +319,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool mla_attn = false; // MLA attention for deepseek2 bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cfe94deaf76ef..7bb95669d48dc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -330,6 +330,7 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, + gguf.MODEL_TENSOR.ATTN_K_B, ) ) or not new_name.endswith(".weight") @@ -4414,6 +4415,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] return [(self.map_tensor_name(name), data_torch)] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + k_b = k_b.reshape(n_head_kv * 
data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] def prepare_tensors(self): super().prepare_tensors() diff --git a/examples/server/README.md b/examples/server/README.md index a2a0903261e31..043c725d8d548 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | +| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)
(env: LLAMA_ARG_MLA_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3a52cfd1e39ac..c5a96e4fc432e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -377,6 +377,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -581,6 +583,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1451,6 +1455,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 50bef12e3dbe7..00733e59fe6d5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -656,6 +656,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 (MLA specific) + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 (MLA specific) + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index fca2b034ba270..2627efb0bd96a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -355,6 +355,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // MLA attention for deepseek2 bool no_perf; // whether to measure performance timings // Abort callback diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 047782e7d0fc8..4bf8ebe85016b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1030,6 +1030,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1471,23 +1473,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, 
{LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 297cfa4dae571..9ed22f33eafb6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -299,6 +299,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3927079432d94..a957aaff3d479 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -37,6 +37,7 @@ llama_context::llama_context( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; cparams.warmup = false; @@ -104,6 +105,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ false, /*.no_perf =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -2274,6 +2277,11 @@ llama_context * llama_init_from_model( params.flash_attn = false; } + if (params.mla_attn && model->arch != LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__); + params.mla_attn = false; + } + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 30e550f023a9e..f2309bc8eef67 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -28,6 +28,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + 
bool mla_attn; bool no_perf; bool warmup; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 7ba546c10ff74..d0faaa3064c7c 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init( recurrent = llama_model_is_recurrent(&model); v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // TODO: support DEEPSEEK2 context shifting LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); @@ -71,8 +71,17 @@ bool llama_kv_cache_unified::init( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + int64_t n_embd_k; + int64_t n_embd_v; + + // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) + if (cparams.mla_attn) { + n_embd_k = hparams.n_lora_kv + hparams.n_rot; + n_embd_v = hparams.n_lora_kv; + } else { + n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + } const char * dev_name = "CPU"; @@ -86,8 +95,8 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %d, n_embd_v = %d, dev = %s\n", __func__, + i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -95,8 +104,8 @@ bool llama_kv_cache_unified::init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l.push_back(k); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ca6e3ab2caeb1..41744e3fda965 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3072,6 +3072,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -9533,6 +9535,124 @@ struct llm_build_deepseek2 : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + // self_attention + { + ggml_tensor * q = NULL; + if (!is_lite) { + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(q, "q", 
il); + + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + q_pe = ggml_cont(ctx0, q_pe); + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + k_pe = ggml_cont(ctx0, k_pe); + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + // TODO: the CUDA backend used to not support non-cont. 
(RMS) norm, investigate removing ggml_cont + kv_compressed = ggml_cont(ctx0, kv_compressed); + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) + if (cparams.mla_attn) { + // TODO: later + } + + // note: deepseek without MLA option converts into MHA + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } + + } + +#if 0 // self_attention { ggml_tensor * q = NULL; @@ -9645,6 +9765,7 @@ struct llm_build_deepseek2 : public llm_graph_context { model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, kq_scale, il); } +#endif if (il == n_layer - 1) { // skip computing output for unused tokens diff --git a/src/llama-model.h b/src/llama-model.h index 91e6e8725acd2..77b4b0e1bc24e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -169,6 +169,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; From 10207b4a6412f0edf7f7e8cfeb0529ab166a5677 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 17:59:12 +0100 Subject: [PATCH 02/23] Fixes --- src/llama-kv-cache.cpp | 2 +- src/llama-model.cpp | 121 +---------------------------------------- 2 files changed, 4 insertions(+), 119 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index d0faaa3064c7c..a736cb710f9a2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -95,7 +95,7 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %d, n_embd_v = %d, dev = %s\n", __func__, + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %lld, n_embd_v = %lld, dev = %s\n", __func__, i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 41744e3fda965..ab738395b0c40 100644 --- 
a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9607,13 +9607,13 @@ struct llm_build_deepseek2 : public llm_graph_context { LLM_NORM_RMS, il); cb(kv_compressed, "kv_compressed", il); - // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) + if (cparams.mla_attn) { + // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) // TODO: later - } - // note: deepseek without MLA option converts into MHA } else { + // note: deepseek without MLA option converts into MHA ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); cb(kv, "kv", il); @@ -9652,121 +9652,6 @@ struct llm_build_deepseek2 : public llm_graph_context { } -#if 0 - // self_attention - { - ggml_tensor * q = NULL; - if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} - q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); - cb(q, "q", il); - - q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, - LLM_NORM_RMS, il); - cb(q, "q", il); - - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} - q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); - cb(q, "q", il); - } else { - q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(q, "q", il); - } - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), - 0); - cb(q_nope, "q_nope", il); - - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), - ggml_row_size(q->type, n_embd_head_qk_nope)); - cb(q_pe, "q_pe", il); - - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); - - // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], - 0); - cb(kv_compressed, "kv_compressed", il); - - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); - - // TODO: the CUDA backend used to not support non-cont. 
(RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); - - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); - - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. 
RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); - - ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); - - ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); - - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); - } -#endif - if (il == n_layer - 1) { // skip computing output for unused tokens ggml_tensor * inp_out_ids = build_inp_out_ids(); From ea3c05bb3ddad5fda0a99540b94c612809324f57 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 18:04:44 +0100 Subject: [PATCH 03/23] Just make `uint32_t n_embd_k` and `uint32_t n_embd_v` --- src/llama-kv-cache.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a736cb710f9a2..50651758c9c75 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -71,8 +71,8 @@ bool llama_kv_cache_unified::init( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - int64_t n_embd_k; - int64_t n_embd_v; + uint32_t n_embd_k; + uint32_t n_embd_v; // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) if (cparams.mla_attn) { @@ -95,7 +95,7 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %lld, n_embd_v = %lld, dev = %s\n", __func__, + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %d, n_embd_v = %d, dev = %s\n", __func__, i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); From 1f604a7ab3fde6d428fc14bbf9d0ecad894a580f Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 19:56:08 +0100 Subject: [PATCH 04/23] First working version --- src/llama-graph.cpp | 138 +++++++++++++++++++++++++++++++++++++++++ src/llama-kv-cache.cpp | 6 +- src/llama-model.cpp | 37 ++++++++++- 3 files changed, 175 insertions(+), 6 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cec203df49268..6b4c933f82531 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1415,6 +1415,144 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn_mla( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wv_b, + ggml_tensor * wo, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto & n_ctx = cparams.n_ctx; + + const auto kv_lora_rank = hparams.n_lora_kv; + + // note: deepseek with MLA option converts into MQA with larger n_ebed (ie: GQA with 1 group) + const int64_t n_embd_k_compressed = kv_lora_rank + hparams.n_rot; + const int64_t n_embd_v_compressed = kv_lora_rank; + + // note: this is the smaller n_ebed what we get after decompression + const int64_t n_embd_head_v = hparams.n_embd_head_v; + + // note: llm_build_deepseek2 passes as: {n_embd, n_tokens, n_head} + const auto n_tokens = q_cur->ne[1]; + const auto n_head = q_cur->ne[2]; + + // store to KV cache + 
{ + const auto kv_head = kv_self->head; + + GGML_ASSERT(kv_self->size == n_ctx); + + ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], + n_tokens*n_embd_k_compressed, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed)*kv_head); + //cb(k_cache_view, "k_cache_view", il); + + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); + + v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_compressed, n_tokens); + + // note: for deepseek MLA the V cache just holds a transposed copy of the K cache + ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], + n_tokens, n_embd_v_compressed, + ( n_ctx)*ggml_element_size(kv_self->v_l[il]), + (kv_head)*ggml_element_size(kv_self->v_l[il])); + + v_cur = ggml_transpose(ctx0, v_cur); + //cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); + } + + // NOTE: We can't pass this to `build_attn_mha()` due to: + // - There's no way to decompress as `build_attn_mha()` applies the kqv_merged and cont inside. + // - We lose the performance gain from using 2D views when applying MQA (ie: GQA with 1 group). + // - TODO: Consider refactoring `build_attn_mha()` and/or adding optimised `build_attn_mqa()` version. + + const auto & kq_mask = inp->get_kq_mask(); + + const auto n_kv = kv_self->n; + + ggml_tensor * k_cache = ggml_view_2d(ctx0, kv_self->k_l[il], + n_embd_k_compressed, n_kv, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed), + 0); + cb(k_cache, "k_cache", il); + + struct ggml_tensor * v_cache_trans = ggml_view_2d(ctx0, kv_self->v_l[il], + n_kv, n_embd_v_compressed, + ggml_element_size(kv_self->v_l[il])*n_ctx, + 0); + cb(v_cache_trans, "v_cache_trans", il); + + ggml_tensor * q_states = ggml_view_2d(ctx0, q_cur, + n_embd_k_compressed, n_tokens*n_head, + ggml_row_size(q_cur->type, n_embd_k_compressed), + 0); + cb(q_states, "q_states_view", il); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k_cache, q_states); + cb(kq, "kq", il); + + kq = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head, + ggml_row_size(kq->type, n_kv), + ggml_row_size(kq->type, n_kv)*n_tokens, + 0); + cb(kq, "kq_view", il); + + ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq_soft_max, "kq_soft_max", il); + + kq_soft_max = ggml_view_2d(ctx0, kq_soft_max, + n_kv, n_tokens*n_head, + ggml_row_size(kq_soft_max->type, n_kv), + 0); + cb(kq_soft_max, "kq_soft_max_view", il); + + ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_cache_trans, kq_soft_max); + cb(kqv_compressed, "kqv_compressed,", il); + + kqv_compressed = ggml_view_3d(ctx0, kqv_compressed, + n_embd_v_compressed, n_tokens, n_head, + ggml_row_size(kqv_compressed->type, n_embd_v_compressed), + ggml_row_size(kqv_compressed->type, n_embd_v_compressed)*n_tokens, + 0); + cb(kqv_compressed, "kqv_compressed_view", il); + + ggml_tensor * wv_b_view = ggml_view_3d(wv_b, + n_embd_v_compressed, n_embd_head_v, n_head, + ggml_row_size(wv_b, n_embd_v_compressed), + ggml_row_size(wv_b, n_embd_v_compressed)*n_embd_head_v, + 0); + cb(wv_b_view, "wv_b_view", il); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_view, kqv_compressed); + cb(kqv, "kqv", il); + + kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv, "kqv_merged", il); + + ggml_tensor * cur = ggml_cont_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = build_lora_mm(wo, cur); + + return cur; +} + 
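Aside on what the new `build_attn_mla()` above computes: once the RoPE'd positional part is set aside, the attention is ordinary MQA over the compressed latent (a single shared K/V of width `kv_lora_rank`), followed by a per-head decompression with `wv_b`. A minimal PyTorch sketch of that equivalence, using made-up toy shapes, a single sequence, no KQ mask, and the RoPE'd part omitted:

    # Sketch only: toy dimensions, no RoPE, no mask, batch size 1.
    import torch

    n_head, kv_lora_rank, n_embd_head_v = 16, 512, 128
    n_tokens, n_kv = 4, 32

    # queries already multiplied by wk_b ("absorbed"), as done in llm_build_deepseek2()
    q_cmpr = torch.randn(n_head, n_tokens, kv_lora_rank)
    # one shared compressed K per cached position (the real K cache also carries the RoPE'd part)
    k_cmpr = torch.randn(n_kv, kv_lora_rank)
    v_cmpr = k_cmpr                                             # compressed V is the same latent
    wv_b   = torch.randn(n_head, n_embd_head_v, kv_lora_rank)   # per-head decompression

    scale = 1.0 / kv_lora_rank**0.5                             # placeholder; the real graph uses kq_scale
    kq    = torch.einsum('htc,kc->htk', q_cmpr, k_cmpr) * scale
    attn  = torch.softmax(kq, dim=-1)
    kqv_c = torch.einsum('htk,kc->htc', attn, v_cmpr)           # still in latent space
    out   = torch.einsum('hvc,htc->thv', wv_b, kqv_c).reshape(n_tokens, n_head * n_embd_head_v)

The real graph additionally concatenates the RoPE'd `q_pe`/`k_pe` onto the latent before the dot product and applies the KQ mask inside `ggml_soft_max_ext()`.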
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { auto inp = std::make_unique(cross); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 50651758c9c75..e1855a11c65aa 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -71,8 +71,8 @@ bool llama_kv_cache_unified::init( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - uint32_t n_embd_k; - uint32_t n_embd_v; + int64_t n_embd_k; + int64_t n_embd_v; // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) if (cparams.mla_attn) { @@ -95,7 +95,7 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %d, n_embd_v = %d, dev = %s\n", __func__, + LLAMA_LOG_DEBUG("%s: layer %3ld: n_embd_k = %ld, n_embd_v = %d, dev = %s\n", __func__, i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ab738395b0c40..bd7ee407f546d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9524,7 +9524,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv_mla(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -9607,13 +9607,44 @@ struct llm_build_deepseek2 : public llm_graph_context { LLM_NORM_RMS, il); cb(kv_compressed, "kv_compressed", il); - if (cparams.mla_attn) { // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) - // TODO: later + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); + cb(q_pe, "q_pe_perm", il); + + k_pe = ggml_view_2d(ctx0, k_pe, n_embd_head_qk_rope, n_tokens, + ggml_row_size(k_pe->type, n_embd_head_qk_rope), + 0); + cb(k_pe, "k_pe_view", il); + + ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, + ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), + 0); + cb(wk_b, "wk_b", il); + + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, kv_compressed, k_pe, 0); + cb(k_states, "k_states", il); + + ggml_tensor * v_states = kv_compressed; + cb(v_states, "v_states", il); + + cur = build_attn_mla(inp_attn, gf, + model.layers[il].wv_b, model.layers[il].wo, + q_states, k_states, v_states, kq_scale, il); } else { // note: deepseek without MLA option converts into MHA + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); cb(kv, "kv", il); From 1de077b32d6ab93cab5979e9641b88b734c59b91 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 19:58:02 +0100 Subject: [PATCH 05/23] Fixed return bug in `DeepseekV2Model` --- convert_hf_to_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7bb95669d48dc..21e8399e3ab8b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4414,7 +4414,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter else: return [] - return [(self.map_tensor_name(name), data_torch)] if 
name.endswith("kv_b_proj.weight"): name_kb = name.replace("kv_b_proj", "k_b_proj") name_vb = name.replace("kv_b_proj", "v_b_proj") @@ -4437,6 +4436,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter (self.map_tensor_name(name_vb), v_b) ] + return [(self.map_tensor_name(name), data_torch)] + def prepare_tensors(self): super().prepare_tensors() From 7f92e7b6c64dbcdab0808e8915fbcfab7c748a24 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 20:09:09 +0100 Subject: [PATCH 06/23] Minor fixes --- src/llama-graph.h | 11 +++++++++++ src/llama-kv-cache.cpp | 4 +++- src/llama-model.cpp | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..c190cc949c702 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -523,6 +523,17 @@ struct llm_graph_context { float kq_scale, int il) const; + ggml_tensor * build_attn_mla( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wv_b, + ggml_tensor * wo, + ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] + ggml_tensor * k_cur, // [n_embd_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_v, n_tokens] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e1855a11c65aa..dfa6c6ac6ef96 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -11,6 +11,8 @@ #include #include +#include + llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } @@ -95,7 +97,7 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3ld: n_embd_k = %ld, n_embd_v = %d, dev = %s\n", __func__, + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %" PRId64 ", n_embd_v = %" PRId64 ", dev = %s\n", __func__, i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bd7ee407f546d..d66d969620483 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9524,7 +9524,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_mla(); + auto * inp_attn = llm_graph_input_attn_kv_unified(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; From 319e3efb0fa0027a21de344e91b22e9b0048a0b5 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 20:13:22 +0100 Subject: [PATCH 07/23] More fixes --- src/llama-graph.cpp | 4 ++-- src/llama-model.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 6b4c933f82531..74a2d1d50bb18 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1532,8 +1532,8 @@ ggml_tensor * llm_graph_context::build_attn_mla( ggml_tensor * wv_b_view = ggml_view_3d(wv_b, n_embd_v_compressed, n_embd_head_v, n_head, - ggml_row_size(wv_b, n_embd_v_compressed), - ggml_row_size(wv_b, n_embd_v_compressed)*n_embd_head_v, + ggml_row_size(wv_b->type, n_embd_v_compressed), + ggml_row_size(wv_b->type, n_embd_v_compressed)*n_embd_head_v, 0); cb(wv_b_view, "wv_b_view", il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d66d969620483..30e5efd2784e7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9524,7 +9524,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // inp_pos - 
contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = llm_graph_input_attn_kv_unified(); + auto * inp_attn = build_attn_inp_kv_unified(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; From ee4b38935a305206c820c144705f1bd48f4c73ee Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 20:18:57 +0100 Subject: [PATCH 08/23] Renamed `wv_b` to `wv_decompress` to avoid confusion with `_b` biases --- src/llama-graph.cpp | 12 ++++++------ src/llama-graph.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 74a2d1d50bb18..fbac8934533cf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1418,7 +1418,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_b, + ggml_tensor * wv_decompress, ggml_tensor * wo, ggml_tensor * q_cur, ggml_tensor * k_cur, @@ -1530,14 +1530,14 @@ ggml_tensor * llm_graph_context::build_attn_mla( 0); cb(kqv_compressed, "kqv_compressed_view", il); - ggml_tensor * wv_b_view = ggml_view_3d(wv_b, + ggml_tensor * wv_decompress_view = ggml_view_3d(ctx0, wv_decompress, n_embd_v_compressed, n_embd_head_v, n_head, - ggml_row_size(wv_b->type, n_embd_v_compressed), - ggml_row_size(wv_b->type, n_embd_v_compressed)*n_embd_head_v, + ggml_row_size(wv_decompress->type, n_embd_v_compressed), + ggml_row_size(wv_decompress->type, n_embd_v_compressed)*n_embd_head_v, 0); - cb(wv_b_view, "wv_b_view", il); + cb(wv_decompress_view, "wv_decompress_view", il); - ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_view, kqv_compressed); + ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_decompress_view, kqv_compressed); cb(kqv, "kqv", il); kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); diff --git a/src/llama-graph.h b/src/llama-graph.h index c190cc949c702..b0e6cfdc49d3c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -526,7 +526,7 @@ struct llm_graph_context { ggml_tensor * build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_b, + ggml_tensor * wv_decompress, ggml_tensor * wo, ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] ggml_tensor * k_cur, // [n_embd_k, n_tokens] From c00cd9e2c71f21663200b1d1a1a07a134e64bd58 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 20:24:07 +0100 Subject: [PATCH 09/23] Better `_compressed` variable names --- src/llama-graph.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fbac8934533cf..5875dde8ec204 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1443,7 +1443,7 @@ ggml_tensor * llm_graph_context::build_attn_mla( // note: this is the smaller n_ebed what we get after decompression const int64_t n_embd_head_v = hparams.n_embd_head_v; - // note: llm_build_deepseek2 passes as: {n_embd, n_tokens, n_head} + // note: call from llm_build_deepseek2 passes as: {n_embd, n_tokens, n_head} const auto n_tokens = q_cur->ne[1]; const auto n_head = q_cur->ne[2]; @@ -1463,7 +1463,6 @@ ggml_tensor * llm_graph_context::build_attn_mla( v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_compressed, n_tokens); - // note: for deepseek MLA the V cache just holds a transposed copy of the K cache ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_compressed, ( n_ctx)*ggml_element_size(kv_self->v_l[il]), @@ -1484,25 +1483,25 @@ ggml_tensor * llm_graph_context::build_attn_mla( const 
auto n_kv = kv_self->n; - ggml_tensor * k_cache = ggml_view_2d(ctx0, kv_self->k_l[il], + ggml_tensor * k_compressed = ggml_view_2d(ctx0, kv_self->k_l[il], n_embd_k_compressed, n_kv, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed), 0); - cb(k_cache, "k_cache", il); + cb(k_compressed, "k_compressed", il); - struct ggml_tensor * v_cache_trans = ggml_view_2d(ctx0, kv_self->v_l[il], + struct ggml_tensor * v_compressed_trans = ggml_view_2d(ctx0, kv_self->v_l[il], n_kv, n_embd_v_compressed, ggml_element_size(kv_self->v_l[il])*n_ctx, 0); - cb(v_cache_trans, "v_cache_trans", il); + cb(v_compressed_trans, "v_compressed_trans", il); - ggml_tensor * q_states = ggml_view_2d(ctx0, q_cur, + ggml_tensor * q_compressed = ggml_view_2d(ctx0, q_cur, n_embd_k_compressed, n_tokens*n_head, ggml_row_size(q_cur->type, n_embd_k_compressed), 0); - cb(q_states, "q_states_view", il); + cb(q_compressed, "q_compressed", il); - ggml_tensor * kq = ggml_mul_mat(ctx0, k_cache, q_states); + ggml_tensor * kq = ggml_mul_mat(ctx0, k_compressed, q_compressed); cb(kq, "kq", il); kq = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head, @@ -1520,7 +1519,7 @@ ggml_tensor * llm_graph_context::build_attn_mla( 0); cb(kq_soft_max, "kq_soft_max_view", il); - ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_cache_trans, kq_soft_max); + ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_compressed_trans, kq_soft_max); cb(kqv_compressed, "kqv_compressed,", il); kqv_compressed = ggml_view_3d(ctx0, kqv_compressed, From 55ad3a7323c8087427cfc5bfb35b42670949d181 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 20:35:21 +0100 Subject: [PATCH 10/23] Minor comment and variable name fixes --- src/llama-graph.cpp | 2 +- src/llama-graph.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5875dde8ec204..5e35e046c43a4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1443,7 +1443,7 @@ ggml_tensor * llm_graph_context::build_attn_mla( // note: this is the smaller n_ebed what we get after decompression const int64_t n_embd_head_v = hparams.n_embd_head_v; - // note: call from llm_build_deepseek2 passes as: {n_embd, n_tokens, n_head} + // note: call from llm_build_deepseek2() passes as: {n_embd, n_tokens, n_head} const auto n_tokens = q_cur->ne[1]; const auto n_head = q_cur->ne[2]; diff --git a/src/llama-graph.h b/src/llama-graph.h index b0e6cfdc49d3c..57de929a42c67 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -531,8 +531,8 @@ struct llm_graph_context { ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] ggml_tensor * k_cur, // [n_embd_k, n_tokens] ggml_tensor * v_cur, // [n_embd_v, n_tokens] - float kq_scale, - int il) const; + float kq_scale, + int il) const; llm_graph_input_attn_cross * build_attn_inp_cross() const; From 0c86f5645b203f197b5c2508b6ad62a67134c724 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 21:11:55 +0100 Subject: [PATCH 11/23] Moved `build_attn_mla` to better location --- src/llama-graph.cpp | 120 ++++++++++++++++++++++---------------------- src/llama-graph.h | 22 ++++---- 2 files changed, 71 insertions(+), 71 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5e35e046c43a4..d59536e21ad73 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1415,6 +1415,66 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { + auto inp = std::make_unique(cross); + + const int32_t n_enc = !cross->v_embd.empty() ? 
cross->n_enc : hparams.n_ctx_train; + + inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp->cross_kq_mask); + + inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + + return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_cross * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto & kq_mask = inp->get_kq_mask_cross(); + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + //cb(k, "k", il); + + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + //cb(k, "v", il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, @@ -1552,66 +1612,6 @@ ggml_tensor * llm_graph_context::build_attn_mla( return cur; } -llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { - auto inp = std::make_unique(cross); - - const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - - inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - ggml_set_input(inp->cross_kq_mask); - - inp->cross_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; - - return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); -} - -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_cross * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - float kq_scale, - int il) const { - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, q_cur); - ggml_build_forward_expand(gf, k_cur); - ggml_build_forward_expand(gf, v_cur); - - const auto & kq_mask = inp->get_kq_mask_cross(); - - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - //cb(k, "v", il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); - - cb(cur, "kqv_out", il); - - if (wo) { - cur = build_lora_mm(wo, cur); - } - - if (wo_b) { - //cb(cur, "kqv_wo", il); - } - - if (wo_b) { - cur = ggml_add(ctx0, cur, wo_b); - } - - return cur; -} - ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_cgraph * gf, ggml_tensor * s, diff --git a/src/llama-graph.h b/src/llama-graph.h index 57de929a42c67..54dfbeef2090f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -523,17 +523,6 @@ struct llm_graph_context { float kq_scale, int il) const; - ggml_tensor * build_attn_mla( - llm_graph_input_attn_kv_unified * inp, - ggml_cgraph * gf, - ggml_tensor * wv_decompress, - ggml_tensor * wo, - ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] - ggml_tensor * k_cur, // [n_embd_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_v, n_tokens] - float kq_scale, - int il) const; - llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( @@ -548,6 +537,17 @@ struct llm_graph_context { float kq_scale, int il) const; + ggml_tensor * build_attn_mla( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wv_decompress, + ggml_tensor * wo, + ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] + ggml_tensor * k_cur, // [n_embd_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_v, n_tokens] + float kq_scale, + int il) const; + // // recurrent // From b0c8a43286d0f30f07f89bff87c94347b3a31dc1 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 21:25:04 +0100 Subject: [PATCH 12/23] Removed `gguf.MODEL_TENSOR.ATTN_K_B` from `prepare_tensors()` for now --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 21e8399e3ab8b..a191e8a898781 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -330,7 +330,6 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, - gguf.MODEL_TENSOR.ATTN_K_B, ) ) or not new_name.endswith(".weight") From 8c329bca6afd25629ab577bfae7bac72b39ff6d9 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 22:04:31 +0100 Subject: [PATCH 13/23] Bumped `wkv_b` and `wk_b` to use F32. 
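The decompression matmuls against `wk_b` and `wkv_b` appear to be numerically sensitive, so this change requests F32 precision for them via `ggml_mul_mat_set_prec()` whenever those weights are not already stored as F32 or BF16. A rough PyTorch sketch of the `wk_b` "absorption" matmul this applies to, with made-up toy shapes and a float32 upcast standing in for `GGML_PREC_F32`:

    # Sketch only: toy shapes, upcast to float32 stands in for GGML_PREC_F32.
    import torch

    n_head, n_embd_head_qk_nope, kv_lora_rank, n_tokens = 16, 128, 512, 4

    wk_b   = torch.randn(n_head, kv_lora_rank, n_embd_head_qk_nope, dtype=torch.float16)
    q_nope = torch.randn(n_head, n_tokens, n_embd_head_qk_nope, dtype=torch.float16)

    # accumulate in float32; the result is what gets concatenated with the RoPE'd q_pe
    q_nope_absorbed = torch.einsum('hck,htk->htc', wk_b.float(), q_nope.float())

The same precision request is applied to the `wkv_b` matmul on the non-MLA path.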
--- src/llama-model.cpp | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 30e5efd2784e7..10a27d7d2c676 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9627,7 +9627,11 @@ struct llm_build_deepseek2 : public llm_graph_context { 0); cb(wk_b, "wk_b", il); + // note: this operation *MUST* use F32 (or have `wk_b` stored as F32 or BF16 in the GGUF) ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); + if (wk_b->type != GGML_TYPE_F32 && wk_b->type != GGML_TYPE_BF16) { + ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); + } cb(q_nope_absorbed, "q_nope_absorbed", il); ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); @@ -9645,28 +9649,32 @@ struct llm_build_deepseek2 : public llm_graph_context { } else { // note: deepseek without MLA option converts into MHA - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + // note: this operation *MUST* use F32 (or have `wkv_b` stored as F32 or BF16 in the GGUF) + ggml_tensor * kv_decompressed = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + if (model.layers[il].wkv_b->type != GGML_TYPE_F32 && model.layers[il].wkv_b->type != GGML_TYPE_BF16) { + ggml_mul_mat_set_prec(kv_decompressed, GGML_PREC_F32); + } + cb(kv_decompressed, "kv_decompressed", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv_decompressed, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv_decompressed->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_decompressed->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); + ggml_tensor * v_states = ggml_view_3d(ctx0, kv_decompressed, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + ggml_row_size(v_states->type, hparams.n_embd_head_v * n_head), 0); cb(v_states, "v_states", il); From 68302eeeb7a82d4fe408df009d7639ce147340c4 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 22:10:53 +0100 Subject: [PATCH 14/23] Use `ggml_mul_mat_set_prec` `GGML_PREC_F32` by default for now --- src/llama-model.cpp | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 10a27d7d2c676..a6126604e3a6b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9514,6 +9514,7 @@ struct 
llm_build_deepseek2 : public llm_graph_context { const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t n_embd_head_v = hparams.n_embd_head_v; ggml_tensor * cur; ggml_tensor * inpL; @@ -9555,14 +9556,16 @@ struct llm_build_deepseek2 : public llm_graph_context { } // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, hparams.n_embd_head_k), ggml_row_size(q->type, hparams.n_embd_head_k * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, hparams.n_embd_head_k), ggml_row_size(q->type, hparams.n_embd_head_k * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); @@ -9578,7 +9581,8 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(kv_compressed, "kv_compressed", il); // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, + n_embd_head_qk_rope, 1, n_tokens, kv_pe_compresseed->nb[1], kv_pe_compresseed->nb[1], ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); @@ -9616,12 +9620,14 @@ struct llm_build_deepseek2 : public llm_graph_context { q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); cb(q_pe, "q_pe_perm", il); - k_pe = ggml_view_2d(ctx0, k_pe, n_embd_head_qk_rope, n_tokens, + k_pe = ggml_view_2d(ctx0, k_pe, + n_embd_head_qk_rope, n_tokens, ggml_row_size(k_pe->type, n_embd_head_qk_rope), 0); cb(k_pe, "k_pe_view", il); - ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, + ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, + n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); @@ -9629,9 +9635,9 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: this operation *MUST* use F32 (or have `wk_b` stored as F32 or BF16 in the GGUF) ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); - if (wk_b->type != GGML_TYPE_F32 && wk_b->type != GGML_TYPE_BF16) { + //if (wk_b->type != GGML_TYPE_F32 && wk_b->type != GGML_TYPE_BF16) { ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); - } + //} cb(q_nope_absorbed, "q_nope_absorbed", il); ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); @@ -9651,30 +9657,33 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: this operation *MUST* use F32 (or have `wkv_b` stored as F32 or BF16 in the GGUF) ggml_tensor * kv_decompressed = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - if (model.layers[il].wkv_b->type != GGML_TYPE_F32 && model.layers[il].wkv_b->type != GGML_TYPE_BF16) { + //if (model.layers[il].wkv_b->type != GGML_TYPE_F32 && model.layers[il].wkv_b->type != GGML_TYPE_BF16) { ggml_mul_mat_set_prec(kv_decompressed, GGML_PREC_F32); - } + //} cb(kv_decompressed, "kv_decompressed", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, 
kv_decompressed, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv_decompressed->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv_decompressed->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv_decompressed, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv_decompressed->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv_decompressed->type, n_head * (n_embd_head_qk_nope + n_embd_head_v)), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv_decompressed, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_tensor * v_states = ggml_view_3d(ctx0, kv_decompressed, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + n_embd_head_v)), + ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + n_embd_head_v)*n_head), ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(v_states->type, hparams.n_embd_head_v * n_head), + v_states = ggml_view_2d(ctx0, v_states, + n_embd_head_v * n_head, n_tokens, + ggml_row_size(v_states->type, n_embd_head_v * n_head), 0); cb(v_states, "v_states", il); From 937a48d539d12bfd43cd5acc1993a0173ccc179c Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 22:34:16 +0100 Subject: [PATCH 15/23] Better/shorter variable names and more tidying up of code --- src/llama-graph.cpp | 76 ++++++++++++++++++++++----------------------- src/llama-graph.h | 2 +- src/llama-model.cpp | 58 ++++++++++++++++------------------ 3 files changed, 66 insertions(+), 70 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index d59536e21ad73..e748b5f22c1bc 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1478,7 +1478,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_decompress, + ggml_tensor * wv_b, ggml_tensor * wo, ggml_tensor * q_cur, ggml_tensor * k_cur, @@ -1497,8 +1497,8 @@ ggml_tensor * llm_graph_context::build_attn_mla( const auto kv_lora_rank = hparams.n_lora_kv; // note: deepseek with MLA option converts into MQA with larger n_ebed (ie: GQA with 1 group) - const int64_t n_embd_k_compressed = kv_lora_rank + hparams.n_rot; - const int64_t n_embd_v_compressed = kv_lora_rank; + const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot; + const int64_t n_embd_v_cmpr = kv_lora_rank; // note: this is the smaller n_ebed what we get after decompression const int64_t n_embd_head_v = hparams.n_embd_head_v; @@ -1514,17 +1514,17 @@ ggml_tensor * llm_graph_context::build_attn_mla( GGML_ASSERT(kv_self->size == n_ctx); ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], - n_tokens*n_embd_k_compressed, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed)*kv_head); + n_tokens*n_embd_k_cmpr, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr)*kv_head); //cb(k_cache_view, "k_cache_view", il); // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); - v_cur = 
ggml_reshape_2d(ctx0, v_cur, n_embd_v_compressed, n_tokens); + v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_cmpr, n_tokens); ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], - n_tokens, n_embd_v_compressed, + n_tokens, n_embd_v_cmpr, ( n_ctx)*ggml_element_size(kv_self->v_l[il]), (kv_head)*ggml_element_size(kv_self->v_l[il])); @@ -1543,34 +1543,34 @@ ggml_tensor * llm_graph_context::build_attn_mla( const auto n_kv = kv_self->n; - ggml_tensor * k_compressed = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_compressed, n_kv, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_compressed), + ggml_tensor * k_cmpr = ggml_view_2d(ctx0, kv_self->k_l[il], + n_embd_k_cmpr, n_kv, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr), 0); - cb(k_compressed, "k_compressed", il); + cb(k_cmpr, "k_cmpr", il); - struct ggml_tensor * v_compressed_trans = ggml_view_2d(ctx0, kv_self->v_l[il], - n_kv, n_embd_v_compressed, + struct ggml_tensor * v_cmpr_trans = ggml_view_2d(ctx0, kv_self->v_l[il], + n_kv, n_embd_v_cmpr, ggml_element_size(kv_self->v_l[il])*n_ctx, 0); - cb(v_compressed_trans, "v_compressed_trans", il); + cb(v_cmpr_trans, "v_cmpr_trans", il); - ggml_tensor * q_compressed = ggml_view_2d(ctx0, q_cur, - n_embd_k_compressed, n_tokens*n_head, - ggml_row_size(q_cur->type, n_embd_k_compressed), + ggml_tensor * q_cmpr = ggml_view_2d(ctx0, q_cur, + n_embd_k_cmpr, n_tokens*n_head, + ggml_row_size(q_cur->type, n_embd_k_cmpr), 0); - cb(q_compressed, "q_compressed", il); + cb(q_cmpr, "q_cmpr", il); - ggml_tensor * kq = ggml_mul_mat(ctx0, k_compressed, q_compressed); - cb(kq, "kq", il); + ggml_tensor * kq_cmpr = ggml_mul_mat(ctx0, k_cmpr, q_cmpr); + cb(kq_cmpr, "kq_cmpr", il); - kq = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head, - ggml_row_size(kq->type, n_kv), - ggml_row_size(kq->type, n_kv)*n_tokens, + kq_cmpr = ggml_view_3d(ctx0, kq_cmpr, n_kv, n_tokens, n_head, + ggml_row_size(kq_cmpr->type, n_kv), + ggml_row_size(kq_cmpr->type, n_kv)*n_tokens, 0); - cb(kq, "kq_view", il); + cb(kq_cmpr, "kq_view", il); - ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq_cmpr, kq_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq_soft_max, "kq_soft_max", il); kq_soft_max = ggml_view_2d(ctx0, kq_soft_max, @@ -1579,24 +1579,24 @@ ggml_tensor * llm_graph_context::build_attn_mla( 0); cb(kq_soft_max, "kq_soft_max_view", il); - ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_compressed_trans, kq_soft_max); - cb(kqv_compressed, "kqv_compressed,", il); + ggml_tensor * kqv_cmpr = ggml_mul_mat(ctx0, v_cmpr_trans, kq_soft_max); + cb(kqv_cmpr, "kqv_cmpr,", il); - kqv_compressed = ggml_view_3d(ctx0, kqv_compressed, - n_embd_v_compressed, n_tokens, n_head, - ggml_row_size(kqv_compressed->type, n_embd_v_compressed), - ggml_row_size(kqv_compressed->type, n_embd_v_compressed)*n_tokens, + kqv_cmpr = ggml_view_3d(ctx0, kqv_cmpr, + n_embd_v_cmpr, n_tokens, n_head, + ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr), + ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr)*n_tokens, 0); - cb(kqv_compressed, "kqv_compressed_view", il); + cb(kqv_cmpr, "kqv_cmpr_view", il); - ggml_tensor * wv_decompress_view = ggml_view_3d(ctx0, wv_decompress, - n_embd_v_compressed, n_embd_head_v, n_head, - ggml_row_size(wv_decompress->type, n_embd_v_compressed), - ggml_row_size(wv_decompress->type, n_embd_v_compressed)*n_embd_head_v, + ggml_tensor * wv_b_view = ggml_view_3d(ctx0, wv_b, + n_embd_v_cmpr, n_embd_head_v, n_head, + ggml_row_size(wv_b->type, 
n_embd_v_cmpr), + ggml_row_size(wv_b->type, n_embd_v_cmpr)*n_embd_head_v, 0); - cb(wv_decompress_view, "wv_decompress_view", il); + cb(wv_b_view, "wv_b_view", il); - ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_decompress_view, kqv_compressed); + ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_view, kqv_cmpr); cb(kqv, "kqv", il); kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); diff --git a/src/llama-graph.h b/src/llama-graph.h index 54dfbeef2090f..fb9f1b6a68196 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -540,7 +540,7 @@ struct llm_graph_context { ggml_tensor * build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_decompress, + ggml_tensor * wv_b, ggml_tensor * wo, ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] ggml_tensor * k_cur, // [n_embd_k, n_tokens] diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a6126604e3a6b..9f62696ad2b88 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9571,21 +9571,21 @@ struct llm_build_deepseek2 : public llm_graph_context { ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_pe_cmprresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_cmprresseed, "kv_pe_cmprresseed", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_pe_cmprresseed, kv_lora_rank, n_tokens, + kv_pe_cmprresseed->nb[1], 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_cmprresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + kv_pe_cmprresseed->nb[1], + kv_pe_cmprresseed->nb[1], + ggml_row_size(kv_pe_cmprresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this @@ -9605,11 +9605,11 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(k_pe, "k_pe", il); // TODO: the CUDA backend used to not support non-cont. 
(RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, + kv_cmpr = ggml_cont(ctx0, kv_cmpr); + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); if (cparams.mla_attn) { // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) @@ -9633,20 +9633,18 @@ struct llm_build_deepseek2 : public llm_graph_context { 0); cb(wk_b, "wk_b", il); - // note: this operation *MUST* use F32 (or have `wk_b` stored as F32 or BF16 in the GGUF) + // note: this operation *MUST* use F32 or it will cause gibberish output ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); - //if (wk_b->type != GGML_TYPE_F32 && wk_b->type != GGML_TYPE_BF16) { - ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); - //} + ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); cb(q_nope_absorbed, "q_nope_absorbed", il); ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(q_states, "q_states", il); - ggml_tensor * k_states = ggml_concat(ctx0, kv_compressed, k_pe, 0); + ggml_tensor * k_states = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(k_states, "k_states", il); - ggml_tensor * v_states = kv_compressed; + ggml_tensor * v_states = kv_cmpr; cb(v_states, "v_states", il); cur = build_attn_mla(inp_attn, gf, @@ -9655,27 +9653,25 @@ struct llm_build_deepseek2 : public llm_graph_context { } else { // note: deepseek without MLA option converts into MHA - // note: this operation *MUST* use F32 (or have `wkv_b` stored as F32 or BF16 in the GGUF) - ggml_tensor * kv_decompressed = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - //if (model.layers[il].wkv_b->type != GGML_TYPE_F32 && model.layers[il].wkv_b->type != GGML_TYPE_BF16) { - ggml_mul_mat_set_prec(kv_decompressed, GGML_PREC_F32); - //} - cb(kv_decompressed, "kv_decompressed", il); + // note: this operation *MUST* use F32 or it will cause gibberish output + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + ggml_mul_mat_set_prec(kv, GGML_PREC_F32); + cb(kv, "kv", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv_decompressed, + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv_decompressed->type, n_embd_head_qk_nope + n_embd_head_v), - ggml_row_size(kv_decompressed->type, n_head * (n_embd_head_qk_nope + n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + n_embd_head_v)), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv_decompressed, + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + n_embd_head_v)), - ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope + n_embd_head_v)*n_head), - ggml_row_size(kv_decompressed->type, (n_embd_head_qk_nope))); + ggml_row_size(kv->type, (n_embd_head_qk_nope + n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); From 1fd0aab3aa638de99280256e2a1d4b51df05ddb0 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 22:39:18 +0100 Subject: [PATCH 16/23] Fixed `kv_cmpr_pe` 
name --- src/llama-model.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9f62696ad2b88..65f79ee2b8da1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9571,21 +9571,21 @@ struct llm_build_deepseek2 : public llm_graph_context { ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - ggml_tensor * kv_pe_cmprresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_cmprresseed, "kv_pe_cmprresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_pe_cmprresseed, kv_lora_rank, n_tokens, - kv_pe_cmprresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + kv_cmpr_pe->nb[1], 0); cb(kv_cmpr, "kv_cmpr", il); // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_cmprresseed, + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_cmprresseed->nb[1], - kv_pe_cmprresseed->nb[1], - ggml_row_size(kv_pe_cmprresseed->type, kv_lora_rank)); + kv_cmpr_pe->nb[1], + kv_cmpr_pe->nb[1], + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this From 4fb439f6ff1037c25c96cfa400027e2311e91418 Mon Sep 17 00:00:00 2001 From: juk Date: Wed, 2 Apr 2025 22:44:12 +0100 Subject: [PATCH 17/23] Added `n_embd_head_k` as constant --- src/llama-model.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 65f79ee2b8da1..e671fb9cb8a30 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9511,10 +9511,11 @@ struct llm_build_deepseek2 : public llm_graph_context { const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + const uint32_t n_embd_head_k = hparams.n_embd_head_k; + const uint32_t n_embd_head_v = hparams.n_embd_head_v; const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; - const uint32_t n_embd_head_v = hparams.n_embd_head_v; ggml_tensor * cur; ggml_tensor * inpL; @@ -9558,16 +9559,16 @@ struct llm_build_deepseek2 : public llm_graph_context { // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); From f9a0ef4ada3205568c47ba613561801a28634621 Mon Sep 17 00:00:00 2001 From: juk Date: Thu, 3 Apr 2025 03:44:22 +0100 Subject: [PATCH 18/23] Fixed to use `build_attn_mha()` now --- src/llama-graph.cpp | 133 
++++++++++++++++++++------------------------ src/llama-graph.h | 11 ++-- src/llama-model.cpp | 14 ++++- 3 files changed, 77 insertions(+), 81 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e748b5f22c1bc..264e77d3c898c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1130,6 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mha_proj, bool v_trans, float kq_scale) const { //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -1141,11 +1142,18 @@ ggml_tensor * llm_graph_context::build_attn_mha( //const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + const auto n_embd = q->ne[0]; + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; + const auto n_head_kv = k->ne[2]; + + // note: when using MLA, the final embedding size will be changed via v_mha_proj + const auto n_embd_head_v = v_mha_proj == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mha_proj->ne[1]; + + GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch"); + GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch"); ggml_tensor * cur; @@ -1164,12 +1172,29 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); } else { + + // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply + if (n_head_kv == 1) { + q = ggml_view_2d(ctx0, q, + n_embd, n_tokens*n_head, + ggml_row_size(q->type, n_embd), + 0); + } + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // note: this op tends to require high floating point range // while for some models F16 is enough, for others it is not, so we default to F32 here ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + if (n_head_kv == 1) { + kq = ggml_view_3d(ctx0, kq, + n_kv, n_tokens, n_head, + ggml_row_size(kq->type, n_kv), + ggml_row_size(kq->type, n_kv)*n_tokens, + 0); + } + if (arch == LLM_ARCH_GROK) { // need to do the following: // multiply by attn_output_multiplyer of 0.08838834764831845 @@ -1200,6 +1225,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + // for deepseek MLA we need to "decompress" from MQA back to MHA + if (v_mha_proj) { + kqv = ggml_mul_mat(ctx0, v_mha_proj, kqv); + } + ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); @@ -1258,7 +1288,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale); cb(cur, "kqv_out", il); @@ -1397,7 +1427,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1456,7 +1486,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, 
k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale); cb(cur, "kqv_out", il); @@ -1478,11 +1508,13 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_b, ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * v_mha_proj, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, + ggml_tensor * kq_b, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1500,12 +1532,8 @@ ggml_tensor * llm_graph_context::build_attn_mla( const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot; const int64_t n_embd_v_cmpr = kv_lora_rank; - // note: this is the smaller n_ebed what we get after decompression - const int64_t n_embd_head_v = hparams.n_embd_head_v; - // note: call from llm_build_deepseek2() passes as: {n_embd, n_tokens, n_head} const auto n_tokens = q_cur->ne[1]; - const auto n_head = q_cur->ne[2]; // store to KV cache { @@ -1534,80 +1562,38 @@ ggml_tensor * llm_graph_context::build_attn_mla( ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); } - // NOTE: We can't pass this to `build_attn_mha()` due to: - // - There's no way to decompress as `build_attn_mha()` applies the kqv_merged and cont inside. - // - We lose the performance gain from using 2D views when applying MQA (ie: GQA with 1 group). - // - TODO: Consider refactoring `build_attn_mha()` and/or adding optimised `build_attn_mqa()` version. - const auto & kq_mask = inp->get_kq_mask(); const auto n_kv = kv_self->n; - ggml_tensor * k_cmpr = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_cmpr, n_kv, + ggml_tensor * k = ggml_view_3d(ctx0, kv_self->k_l[il], + n_embd_k_cmpr, n_kv, 1, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr), ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr), 0); - cb(k_cmpr, "k_cmpr", il); + //cb(k, "k", il); - struct ggml_tensor * v_cmpr_trans = ggml_view_2d(ctx0, kv_self->v_l[il], - n_kv, n_embd_v_cmpr, + struct ggml_tensor * v = ggml_view_3d(ctx0, kv_self->v_l[il], + n_kv, n_embd_v_cmpr, 1, + ggml_element_size(kv_self->v_l[il])*n_ctx, ggml_element_size(kv_self->v_l[il])*n_ctx, 0); - cb(v_cmpr_trans, "v_cmpr_trans", il); - - ggml_tensor * q_cmpr = ggml_view_2d(ctx0, q_cur, - n_embd_k_cmpr, n_tokens*n_head, - ggml_row_size(q_cur->type, n_embd_k_cmpr), - 0); - cb(q_cmpr, "q_cmpr", il); - - ggml_tensor * kq_cmpr = ggml_mul_mat(ctx0, k_cmpr, q_cmpr); - cb(kq_cmpr, "kq_cmpr", il); - - kq_cmpr = ggml_view_3d(ctx0, kq_cmpr, n_kv, n_tokens, n_head, - ggml_row_size(kq_cmpr->type, n_kv), - ggml_row_size(kq_cmpr->type, n_kv)*n_tokens, - 0); - cb(kq_cmpr, "kq_view", il); - - ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq_cmpr, kq_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq_soft_max, "kq_soft_max", il); - - kq_soft_max = ggml_view_2d(ctx0, kq_soft_max, - n_kv, n_tokens*n_head, - ggml_row_size(kq_soft_max->type, n_kv), - 0); - cb(kq_soft_max, "kq_soft_max_view", il); - - ggml_tensor * kqv_cmpr = ggml_mul_mat(ctx0, v_cmpr_trans, kq_soft_max); - cb(kqv_cmpr, "kqv_cmpr,", il); - - kqv_cmpr = ggml_view_3d(ctx0, kqv_cmpr, - n_embd_v_cmpr, n_tokens, n_head, - ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr), - ggml_row_size(kqv_cmpr->type, n_embd_v_cmpr)*n_tokens, - 0); - cb(kqv_cmpr, "kqv_cmpr_view", il); - - ggml_tensor * wv_b_view = ggml_view_3d(ctx0, wv_b, - n_embd_v_cmpr, n_embd_head_v, n_head, - ggml_row_size(wv_b->type, 
n_embd_v_cmpr), - ggml_row_size(wv_b->type, n_embd_v_cmpr)*n_embd_head_v, - 0); - cb(wv_b_view, "wv_b_view", il); - - ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_view, kqv_cmpr); - cb(kqv, "kqv", il); + //cb(v, "v", il); - kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv, "kqv_merged", il); + ggml_tensor * cur = build_attn_mha(gf, q_cur, k, v, kq_b, kq_mask, v_mha_proj, true, kq_scale); + cb(cur, "kqv_out", il); - ggml_tensor * cur = ggml_cont_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_cont", il); + if (wo) { + cur = build_lora_mm(wo, cur); + } - ggml_build_forward_expand(gf, cur); + if (wo_b) { + //cb(cur, "kqv_wo", il); + } - cur = build_lora_mm(wo, cur); + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } return cur; } @@ -1762,4 +1748,3 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - diff --git a/src/llama-graph.h b/src/llama-graph.h index fb9f1b6a68196..a52fb5631e123 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -492,6 +492,7 @@ struct llm_graph_context { ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mha_proj, bool v_trans, float kq_scale) const; @@ -540,11 +541,13 @@ struct llm_graph_context { ggml_tensor * build_attn_mla( llm_graph_input_attn_kv_unified * inp, ggml_cgraph * gf, - ggml_tensor * wv_b, ggml_tensor * wo, - ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head] - ggml_tensor * k_cur, // [n_embd_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_v, n_tokens] + ggml_tensor * wo_b, + ggml_tensor * v_mha_proj, + ggml_tensor * q_cur, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k_cur, // [n_embd_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_tokens] + ggml_tensor * kq_b, float kq_scale, int il) const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e671fb9cb8a30..b11fbda09359b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9634,7 +9634,8 @@ struct llm_build_deepseek2 : public llm_graph_context { 0); cb(wk_b, "wk_b", il); - // note: this operation *MUST* use F32 or it will cause gibberish output + // note: this operation *MUST* use F32 or it will cause gibberish output, as this + // effectively does the KQ multiplication here instead of in build_attn_mha() ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); cb(q_nope_absorbed, "q_nope_absorbed", il); @@ -9648,9 +9649,16 @@ struct llm_build_deepseek2 : public llm_graph_context { ggml_tensor * v_states = kv_cmpr; cb(v_states, "v_states", il); + ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, + kv_lora_rank, n_embd_head_v, n_head, + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank), + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_v), + 0); + cb(wk_b, "wv_b", il); + cur = build_attn_mla(inp_attn, gf, - model.layers[il].wv_b, model.layers[il].wo, - q_states, k_states, v_states, kq_scale, il); + model.layers[il].wo, NULL, wv_b, + q_states, k_states, v_states, nullptr, kq_scale, il); } else { // note: deepseek without MLA option converts into MHA From 5fe402aa6ed507a4aa00d74a84be6f1fb11b8a50 Mon Sep 17 00:00:00 2001 From: juk Date: Thu, 3 Apr 2025 05:15:00 +0100 Subject: [PATCH 19/23] `mla_attn` on then not `flash_attn` so we can run `-fa` for draft models --- src/llama-context.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a957aaff3d479..79a96a30d67eb 
100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2282,6 +2282,11 @@ llama_context * llama_init_from_model( params.mla_attn = false; } + if (params.mla_attn && params.flash_attn) { + LLAMA_LOG_WARN("%s: mla_attn is not compatible with flash_attn - forcing off\n", __func__); + params.flash_attn = false; + } + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; From 9b862f98c87ded97f4ba93711f1e06d7fb9f85f9 Mon Sep 17 00:00:00 2001 From: juk Date: Thu, 3 Apr 2025 05:51:29 +0100 Subject: [PATCH 20/23] "flash_attn is not compatible with mla_attn" --> flash_attn off --- src/llama-context.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 79a96a30d67eb..e4951543db18f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2282,8 +2282,8 @@ llama_context * llama_init_from_model( params.mla_attn = false; } - if (params.mla_attn && params.flash_attn) { - LLAMA_LOG_WARN("%s: mla_attn is not compatible with flash_attn - forcing off\n", __func__); + if (params.flash_attn && params.mla_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__); params.flash_attn = false; } From 8e23e0da2ee7d1e124aff3779ab08605abac7d45 Mon Sep 17 00:00:00 2001 From: juk Date: Thu, 3 Apr 2025 14:27:30 +0100 Subject: [PATCH 21/23] Fixed subtle bug caused by `-mla` for speculative models --- src/llama-kv-cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index dfa6c6ac6ef96..a54207fb31e8d 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -77,7 +77,7 @@ bool llama_kv_cache_unified::init( int64_t n_embd_v; // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) - if (cparams.mla_attn) { + if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { n_embd_k = hparams.n_lora_kv + hparams.n_rot; n_embd_v = hparams.n_lora_kv; } else { From b3840861a22ebd9721d309865f04df227f5da468 Mon Sep 17 00:00:00 2001 From: juk Date: Fri, 4 Apr 2025 23:54:47 +0100 Subject: [PATCH 22/23] Removed need for `v_b_proj` storing. 
Tidied all ggml_row_size for quants --- convert_hf_to_gguf.py | 13 +++--- gguf-py/gguf/constants.py | 9 ++--- gguf-py/gguf/tensor_mapping.py | 8 +--- src/llama-arch.cpp | 6 +-- src/llama-arch.h | 3 +- src/llama-context.cpp | 19 +++++---- src/llama-kv-cache.cpp | 3 +- src/llama-model.cpp | 74 +++++++++++++++++----------------- src/llama-model.h | 37 +++++++++-------- 9 files changed, 81 insertions(+), 91 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a191e8a898781..daff8f87d6e77 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4414,8 +4414,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] if name.endswith("kv_b_proj.weight"): - name_kb = name.replace("kv_b_proj", "k_b_proj") - name_vb = name.replace("kv_b_proj", "v_b_proj") + name_kb = name.replace("kv_b_proj", "k_b_proj_trans") n_head_kv = self.hparams["num_key_value_heads"] v_head_dim = self.hparams["v_head_dim"] @@ -4424,15 +4423,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) - k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) - k_b = k_b.transpose(1, 2) - k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) - v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + k_b = kv_b[:, :qk_nope_head_dim, :] + k_b_trans = k_b.transpose(1, 2) + k_b_trans = k_b_trans.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) return [ (self.map_tensor_name(name), data_torch), - (self.map_tensor_name(name_kb), k_b), - (self.map_tensor_name(name_vb), v_b) + (self.map_tensor_name(name_kb), k_b_trans), ] return [(self.map_tensor_name(name), data_torch)] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5a96e4fc432e..a422f4ea326ec 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -377,8 +377,7 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() - ATTN_K_B = auto() - ATTN_V_B = auto() + ATTN_K_B_TRANS = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -583,8 +582,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", - MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", - MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", + MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1455,8 +1453,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, - MODEL_TENSOR.ATTN_K_B, - MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_K_B_TRANS, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 00733e59fe6d5..0f95e73761c97 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -656,12 +656,8 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), - MODEL_TENSOR.ATTN_K_B: ( - "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 (MLA specific) - ), - - MODEL_TENSOR.ATTN_V_B: ( - "model.layers.{bid}.self_attn.v_b_proj", # 
deepseek2 (MLA specific) + MODEL_TENSOR.ATTN_K_B_TRANS: ( + "model.layers.{bid}.self_attn.k_b_proj_trans", # deepseek2 (MLA specific) ), MODEL_TENSOR.ATTN_Q_A_NORM: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4bf8ebe85016b..c4b052ce80b5e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1030,8 +1030,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, - { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, - { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1473,8 +1472,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 9ed22f33eafb6..523c1457d7c73 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -299,8 +299,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_K_B_TRANS, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e4951543db18f..e7e440047d7c7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2277,14 +2277,17 @@ llama_context * llama_init_from_model( params.flash_attn = false; } - if (params.mla_attn && model->arch != LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__); - params.mla_attn = false; - } - - if (params.flash_attn && params.mla_attn) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__); - params.flash_attn = false; + if (params.mla_attn) { + if (model->arch != LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__); + params.mla_attn = false; + } else if (model->layers[0].wk_b_trans == nullptr) { + LLAMA_LOG_WARN("%s: mla_attn requires a gguf with the new 'attn_k_b_trans' tensor - forcing off\n", __func__); + params.mla_attn = false; + } else if (params.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__); + params.flash_attn = false; + } } if (ggml_is_quantized(params.type_v) && !params.flash_attn) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a54207fb31e8d..fbe80b0453d40 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -76,8 +76,9 @@ bool llama_kv_cache_unified::init( int64_t n_embd_k; int64_t n_embd_v; - // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) + // note: be sure to check model.arch or this will cause a bug if used with a non-MLA draft 
model! if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { + // note: deepseek2 with MLA option converts into MQA (ie: GQA with 1 group) n_embd_k = hparams.n_lora_kv + hparams.n_rot; n_embd_v = hparams.n_lora_kv; } else { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b11fbda09359b..9a0d9d0b27343 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3070,11 +3070,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wk_b_trans = create_tensor(tn(LLM_TENSOR_ATTN_K_B_TRANS, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -9556,19 +9555,19 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} + // split into {n_embd_head_qk_nope, n_head, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), - ggml_row_size(q->type, n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} + // and {n_embd_head_qk_rope, n_head, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), - ggml_row_size(q->type, n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); @@ -9576,16 +9575,17 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, - kv_cmpr_pe->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cmpr, "kv_cmpr", il); - // and {n_embd_head_qk_rope, n_tokens} + // and {n_embd_head_qk_rope, 1, n_tokens} ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, - kv_cmpr_pe->nb[1], - kv_cmpr_pe->nb[1], + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); @@ -9613,7 +9613,7 @@ struct llm_build_deepseek2 : public llm_graph_context { 
cb(kv_cmpr, "kv_cmpr", il); if (cparams.mla_attn) { - // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group) + GGML_ASSERT(model.layers[il].wk_b_trans != nullptr); // should not get here, see: llama_init_from_model() q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); cb(q_nope, "q_nope_perm", il); @@ -9627,16 +9627,15 @@ struct llm_build_deepseek2 : public llm_graph_context { 0); cb(k_pe, "k_pe_view", il); - ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, + ggml_tensor * wk_b_trans = ggml_view_3d(ctx0, model.layers[il].wk_b_trans, n_embd_head_qk_nope, kv_lora_rank, n_head, - ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), - ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope) * kv_lora_rank, 0); - cb(wk_b, "wk_b", il); + cb(wk_b_trans, "wk_b_trans", il); - // note: this operation *MUST* use F32 or it will cause gibberish output, as this - // effectively does the KQ multiplication here instead of in build_attn_mha() - ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b, q_nope); + // note: this operation seems to need F32 precision, needs further investigation/testing + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b_trans, q_nope); ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); cb(q_nope_absorbed, "q_nope_absorbed", il); @@ -9649,38 +9648,38 @@ struct llm_build_deepseek2 : public llm_graph_context { ggml_tensor * v_states = kv_cmpr; cb(v_states, "v_states", il); - ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, + // {n_embd_head_v, n_head, n_tokens} + ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wkv_b, kv_lora_rank, n_embd_head_v, n_head, - ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank), - ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_v), - 0); - cb(wk_b, "wv_b", il); + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank), + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * (n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * n_embd_head_qk_nope); + cb(wv_b, "wv_b", il); + // note: deepseek2 with MLA option converts into MQA (ie: GQA with 1 group) cur = build_attn_mla(inp_attn, gf, model.layers[il].wo, NULL, wv_b, q_states, k_states, v_states, nullptr, kq_scale, il); } else { - // note: deepseek without MLA option converts into MHA - - // note: this operation *MUST* use F32 or it will cause gibberish output + // note: this operation seems to need F32 precision, needs further investigation/testing ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); ggml_mul_mat_set_prec(kv, GGML_PREC_F32); cb(kv, "kv", il); - // split into {n_head * n_embd_head_qk_nope, n_tokens} + // split into {n_embd_head_qk_nope, n_head, n_tokens} ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, 0); cb(k_nope, "k_nope", il); - // and {n_head * n_embd_head_v, n_tokens} + // and {n_embd_head_v, n_head, n_tokens} ggml_tensor * v_states = ggml_view_3d(ctx0, kv, n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope 
+ n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); @@ -9688,7 +9687,7 @@ struct llm_build_deepseek2 : public llm_graph_context { v_states = ggml_view_2d(ctx0, v_states, n_embd_head_v * n_head, n_tokens, - ggml_row_size(v_states->type, n_embd_head_v * n_head), + ggml_row_size(v_states->type, n_embd_head_v) * n_head, 0); cb(v_states, "v_states", il); @@ -9698,6 +9697,7 @@ struct llm_build_deepseek2 : public llm_graph_context { ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); + // note: deepseek2 without MLA option converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, kq_scale, il); diff --git a/src/llama-model.h b/src/llama-model.h index 77b4b0e1bc24e..9e6439c6996b6 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -160,25 +160,24 @@ struct llama_layer { struct ggml_tensor * attn_norm_enc = nullptr; // attention - struct ggml_tensor * wq = nullptr; - struct ggml_tensor * wk = nullptr; - struct ggml_tensor * wv = nullptr; - struct ggml_tensor * wo = nullptr; - struct ggml_tensor * wqkv = nullptr; - struct ggml_tensor * wq_a = nullptr; - struct ggml_tensor * wq_b = nullptr; - struct ggml_tensor * wkv_a_mqa = nullptr; - struct ggml_tensor * wkv_b = nullptr; - struct ggml_tensor * wk_b = nullptr; - struct ggml_tensor * wv_b = nullptr; - struct ggml_tensor * wq_cross = nullptr; - struct ggml_tensor * wk_cross = nullptr; - struct ggml_tensor * wv_cross = nullptr; - struct ggml_tensor * wo_cross = nullptr; - struct ggml_tensor * wq_enc = nullptr; - struct ggml_tensor * wk_enc = nullptr; - struct ggml_tensor * wv_enc = nullptr; - struct ggml_tensor * wo_enc = nullptr; + struct ggml_tensor * wq = nullptr; + struct ggml_tensor * wk = nullptr; + struct ggml_tensor * wv = nullptr; + struct ggml_tensor * wo = nullptr; + struct ggml_tensor * wqkv = nullptr; + struct ggml_tensor * wq_a = nullptr; + struct ggml_tensor * wq_b = nullptr; + struct ggml_tensor * wkv_a_mqa = nullptr; + struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b_trans = nullptr; + struct ggml_tensor * wq_cross = nullptr; + struct ggml_tensor * wk_cross = nullptr; + struct ggml_tensor * wv_cross = nullptr; + struct ggml_tensor * wo_cross = nullptr; + struct ggml_tensor * wq_enc = nullptr; + struct ggml_tensor * wk_enc = nullptr; + struct ggml_tensor * wv_enc = nullptr; + struct ggml_tensor * wo_enc = nullptr; // attention bias struct ggml_tensor * bq = nullptr; From 5dbf99c38b246fc7170ac4ff61a878bb06f7d00d Mon Sep 17 00:00:00 2001 From: juk Date: Sat, 5 Apr 2025 00:37:57 +0100 Subject: [PATCH 23/23] Removed both calls to `ggml_mul_mat_set_prec` for MLA and non-MLA cases --- src/llama-model.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9a0d9d0b27343..a6f7faf103298 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9634,9 +9634,7 @@ struct llm_build_deepseek2 : public llm_graph_context { 0); cb(wk_b_trans, "wk_b_trans", il); - // note: this operation seems to need F32 precision, needs further investigation/testing ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b_trans, q_nope); - 
ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32); cb(q_nope_absorbed, "q_nope_absorbed", il); ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); @@ -9661,9 +9659,7 @@ struct llm_build_deepseek2 : public llm_graph_context { model.layers[il].wo, NULL, wv_b, q_states, k_states, v_states, nullptr, kq_scale, il); } else { - // note: this operation seems to need F32 precision, needs further investigation/testing ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); - ggml_mul_mat_set_prec(kv, GGML_PREC_F32); cb(kv, "kv", il); // split into {n_embd_head_qk_nope, n_head, n_tokens}
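// A minimal, self-contained sketch of the MLA-as-MQA shape algebra that the
// series above implements: absorb the per-head wk_b projection into the query,
// attend directly over the compressed latent cache (kv_lora_rank + n_rot per
// key, kv_lora_rank per value), and decompress with wv_b only after the
// attention-weighted sum. Plain C++ with made-up dimensions and constant
// weights; an illustration of the idea, not ggml code and not part of a patch.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_head   = 2; // attention heads
    const int d_nope   = 4; // qk_nope_head_dim
    const int d_rope   = 2; // qk_rope_head_dim (n_rot)
    const int d_latent = 3; // kv_lora_rank
    const int d_v      = 4; // v_head_dim
    const int n_kv     = 5; // cached positions

    // per-head blocks of the decoder weights (constant here for simplicity):
    //   wk_b_trans[h] : d_latent x d_nope   (transposed k_b block of wkv_b)
    //   wv_b[h]       : d_v      x d_latent (v_b block of wkv_b)
    std::vector<float> wk_b_trans(n_head * d_latent * d_nope, 0.01f);
    std::vector<float> wv_b      (n_head * d_v      * d_latent, 0.02f);

    // the cache stores only the compressed latent plus the shared RoPE-ed key
    std::vector<float> c_kv(n_kv * d_latent, 0.10f); // compressed KV latents
    std::vector<float> k_pe(n_kv * d_rope,   0.10f); // RoPE key part (1 group)

    // current token's query, already split into nope/rope parts per head
    std::vector<float> q_nope(n_head * d_nope, 0.5f);
    std::vector<float> q_pe  (n_head * d_rope, 0.5f);

    const float scale = 1.0f / std::sqrt(float(d_nope + d_rope));

    std::vector<float> out(n_head * d_v, 0.0f);

    for (int h = 0; h < n_head; ++h) {
        // 1) absorb wk_b into the query: q_abs = wk_b_trans[h] * q_nope[h]
        std::vector<float> q_abs(d_latent, 0.0f);
        for (int r = 0; r < d_latent; ++r)
            for (int c = 0; c < d_nope; ++c)
                q_abs[r] += wk_b_trans[(h*d_latent + r)*d_nope + c] * q_nope[h*d_nope + c];

        // 2) scores against the compressed cache; every head shares c_kv/k_pe (MQA)
        std::vector<float> p(n_kv);
        float max_s = -1e30f;
        for (int t = 0; t < n_kv; ++t) {
            float s = 0.0f;
            for (int r = 0; r < d_latent; ++r) s += q_abs[r]           * c_kv[t*d_latent + r];
            for (int r = 0; r < d_rope;   ++r) s += q_pe[h*d_rope + r] * k_pe[t*d_rope   + r];
            p[t]  = s * scale;
            max_s = std::max(max_s, p[t]);
        }
        float sum = 0.0f;
        for (int t = 0; t < n_kv; ++t) { p[t] = std::exp(p[t] - max_s); sum += p[t]; }

        // 3) the attention-weighted sum stays in latent space ...
        std::vector<float> o_lat(d_latent, 0.0f);
        for (int t = 0; t < n_kv; ++t)
            for (int r = 0; r < d_latent; ++r)
                o_lat[r] += (p[t] / sum) * c_kv[t*d_latent + r];

        // 4) ... and is decompressed per head with wv_b only at the end
        for (int r = 0; r < d_v; ++r)
            for (int c = 0; c < d_latent; ++c)
                out[h*d_v + r] += wv_b[(h*d_v + r)*d_latent + c] * o_lat[c];
    }

    std::printf("out[0] = %f\n", out[0]);
    return 0;
}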