From 773b6e3912d7db212c124c3c7ce2a98bee3775cc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 16:41:16 +0300 Subject: [PATCH 01/14] kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci --- include/llama.h | 7 +- src/llama-batch.cpp | 51 +++- src/llama-batch.h | 31 ++- src/llama-context.cpp | 178 ++++++------- src/llama-context.h | 4 + src/llama-kv-cache.cpp | 557 +++++++++++++++++++++++++++------------- src/llama-kv-cache.h | 201 +++++---------- src/llama-kv-cells.h | 17 ++ src/llama-memory.h | 27 ++ src/llama-model.cpp | 2 +- tools/server/server.cpp | 11 +- 11 files changed, 647 insertions(+), 439 deletions(-) diff --git a/include/llama.h b/include/llama.h index 01762bea2bf96..f0a1061dec960 100644 --- a/include/llama.h +++ b/include/llama.h @@ -259,9 +259,9 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id; + int8_t * logits; // TODO: rename this to "output" } llama_batch; enum llama_model_kv_override_type { @@ -698,6 +698,7 @@ extern "C" { LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + // TODO: deprecate and always update the cache lazily LLAMA_API void llama_kv_self_update(struct llama_context * ctx); // diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index b98e3256c390d..ac6bbef57d881 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -4,6 +4,21 @@ #include #include +void llama_ubatch::update() { + if (equal_seqs) { + // TODO: for now don't compute min/max for recurrent batches since we don't need this. + // the batches will be refactored anyway, so we'll fix this later + return; + } + + for (uint32_t i = 0; i < n_tokens; ++i) { + const llama_seq_id s = seq_id[i][0]; + + seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]); + seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]); + } +} + llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { // clear empty sequences // the previous ubatch is assumed to be gone, @@ -15,24 +30,33 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { break; } } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), + /*seq_pos_min =*/ {-1}, + /*seq_pos_max =*/ {-1}, + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), }; + return ubatch; } @@ -148,6 +172,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits add_seq_to_ubatch(ubatch, s, length); } + ubatch.update(); return ubatch; } @@ -175,6 +200,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { if (length + n_tokens_in_ubatch > n_ubatch) { break; } } } + ubatch.update(); return ubatch; } @@ -187,6 +213,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits add_seq_to_ubatch(ubatch, s, length); } + ubatch.update(); return ubatch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 6305051b62b79..bd65ec6a935b4 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-cparams.h" #include #include @@ -8,18 +9,23 @@ // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { + void update(); + bool equal_seqs; // TODO: whole_seqs for embeddings? - uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; + llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence + llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence + llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - int32_t * n_seq_id; // [n_seqs] - llama_seq_id ** seq_id; // [n_seqs] + int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id; int8_t * output; // [n_tokens] }; @@ -49,13 +55,18 @@ struct llama_sbatch { const llama_batch * batch = nullptr; - // buffers for the ubatch - std::vector ubatch_token; - std::vector ubatch_embd; - std::vector ubatch_pos; - std::vector ubatch_n_seq_id; - std::vector ubatch_seq_id; - std::vector ubatch_output; + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e153351af3809..534448758e730 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -259,15 +259,9 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - - // restore later - // TODO: something cleaner - const auto n_outputs_save = n_outputs; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); int n_splits_pp = -1; @@ -285,17 +279,8 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - // max number of outputs - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -305,16 +290,8 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_tg.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(1, 1, 1); + if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -324,22 +301,12 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - n_outputs = n_outputs_save; - for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; @@ -454,33 +421,22 @@ const llama_kv_cache * llama_context::get_kv_self() const { } void llama_context::kv_self_update() { - bool need_reserve = false; + if (!memory) { + return; + } llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = kv_self->update(*this); - - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - // simulate full KV cache + if (kv_self->update(*this)) { + // if the KV cache did any computation, we have to reserve a new worst-case graph kv_self->set_full(); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); } } } @@ -737,8 +693,6 @@ int llama_context::encode(llama_batch & inp_batch) { n_outputs = n_tokens; - //batch_manager->prepare(ubatch); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -889,8 +843,6 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self); - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT // TODO: move the validation to the llama_batch_allocr @@ -936,7 +888,28 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); + // handle any pending defrags/shifts + kv_self_update(); + + auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!decode_state) { + return -2; + } + + switch (decode_state->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + // not a fatal error, we can re-try with a different batch + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return -2; + } + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -944,13 +917,10 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; - // handle any pending defrags/shifts - kv_self_update(); - int64_t n_outputs_prev = 0; - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + while (const auto * ubatch_ptr = decode_state->next()) { + const auto & ubatch = *ubatch_ptr; // count the outputs in this u_batch { @@ -969,11 +939,6 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // find KV slot - if (!kv_self->find_slot(ubatch)) { - return 1; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -1084,9 +1049,6 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_prev += n_outputs; } - // finalize the batch processing - kv_guard.commit(); - // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1094,7 +1056,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = sbatch.out_ids; + auto & out_ids = decode_state->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1254,6 +1216,39 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + llm_graph_result_ptr llama_context::graph_build( ggml_context * ctx, ggml_cgraph * gf, @@ -1951,7 +1946,6 @@ void llama_context::opt_epoch_iter( llama_kv_cache * kv_self = static_cast(memory.get()); kv_self->clear(); - llama_kv_cache_guard kv_guard(kv_self); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -1974,7 +1968,11 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1982,18 +1980,12 @@ void llama_context::opt_epoch_iter( GGML_ABORT("TODO: handle this error"); }; - for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + uint32_t pos_batch = 0; + while (const auto * ubatch_ptr = decode_state->next()) { + const auto & ubatch = *ubatch_ptr; n_outputs = ubatch.n_tokens; - // TODO: not sure if this is needed - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - GGML_ABORT("TODO: handle this error"); - } - auto * gf = graph_init(); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); @@ -2027,10 +2019,10 @@ void llama_context::opt_epoch_iter( callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); } ggml_free(ctx_compute_opt); + + pos_batch += ubatch.n_tokens; } } - - kv_guard.commit(); } void llama_context::opt_epoch( diff --git a/src/llama-context.h b/src/llama-context.h index c0ceacb10ce6f..db9d388d0d42f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -47,6 +47,7 @@ struct llama_context { llama_kv_cache * get_kv_self(); const llama_kv_cache * get_kv_self() const; + // TODO: remove void kv_self_update(); enum llama_pooling_type pooling_type() const; @@ -184,6 +185,9 @@ struct llama_context { ggml_cgraph * gf, bool batched); + // reserve a graph + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + private: llm_graph_result_ptr graph_build( ggml_context * ctx, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 766f8d079afb2..1ef3c5b2017e5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -17,6 +17,61 @@ // llama_kv_cache_unified // +class llama_kv_cache_unified_decode_state_t : public llama_memory_decode_state_i { +public: + llama_kv_cache_unified_decode_state_t(llama_memory_status status) : status(status) {} + + llama_kv_cache_unified_decode_state_t( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads(std::move(heads)), + ubatches(std::move(ubatches)) { + } + + ~llama_kv_cache_unified_decode_state_t() = default; + + llama_ubatch * next() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (i_next >= ubatches.size()) { + return nullptr; + } + + kv->fill_slot(heads[i_next], ubatches[i_next]); + + return &ubatches[i_next++]; + } + + std::vector & out_ids() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; + } + + llama_memory_status get_status() const override { + return status; + } + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads; + std::vector ubatches; +}; + uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; @@ -293,26 +348,77 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -void llama_kv_cache_unified::restore() { - for (auto & state : recovery.states) { - cells.set(state.i, state.cells); +llama_memory_decode_state_ptr llama_kv_cache_unified::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - recovery.clear(); + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads), std::move(ubatches)); } -void llama_kv_cache_unified::commit() { - if (recovery.states.empty()) { - LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); - return; +std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { + std::vector res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. don't modify the cells yet + const int32_t head_new = find_slot(ubatch); + if (head_new < 0) { + success = false; + break; + } + + // remeber the position that we found + res.push_back(head_new); + + // store the old state of the cells in the recovery stack + states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + + // now emplace the ubatch + fill_slot(head_new, ubatch); + } + + // iterate backwards and restore the cells to their original state + for (auto it = states.rbegin(); it != states.rend(); ++it) { + cells.set(it->head_new, it->cells); + head = it->head_old; + } + + if (!success) { + return {}; } - recovery.clear(); + return res; } bool llama_kv_cache_unified::update(llama_context & lctx) { - bool need_reserve = false; + bool updated = false; auto * sched = lctx.get_sched(); @@ -330,14 +436,24 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); + return updated; + } - ggml_backend_sched_alloc_graph(sched, gf); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); + return updated; + } res->set_inputs(nullptr); - lctx.graph_compute(gf, false); + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); + return updated; + } - need_reserve = true; + updated = true; } cells.reset_shift(); @@ -352,20 +468,30 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } - ggml_backend_sched_alloc_graph(sched, gf); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } res->set_inputs(nullptr); - lctx.graph_compute(gf, false); + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } - need_reserve = true; + updated = true; } do_defrag = false; } - return need_reserve; + return updated; } void llama_kv_cache_unified::defrag_sched(float thold) { @@ -392,40 +518,33 @@ void llama_kv_cache_unified::set_full() { head = 0; } -llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; + uint32_t head_cur = this->head; + // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (head > cells.get_used() + 2*ubatch.n_tokens) { - head = 0; + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; } // otherwise, one cell per token. if (n_tokens > cells.size()) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return false; + return -1; } //#define FIND_SLOT_DEBUG 1 #if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, cells.get_used(), head, n_swa); // for debugging { std::string ss; if (n_swa > 0) { - for (uint32_t i = 0; i < size; ++i) { + for (uint32_t i = 0; i < cells.size(); ++i) { if (cells.is_empty(i)) { ss += '.'; } else { @@ -443,18 +562,31 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { uint32_t n_tested = 0; while (true) { - if (head + n_tokens > cells.size()) { - n_tested += cells.size() - head; - head = 0; + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; continue; } bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { - // TODO: improve to accept cells that are masked by the SWA - if (!cells.is_empty(head + i)) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by the same sequence, and the sequence is not masked + const bool can_use = + cells.is_empty(head_cur + i) || + ( + cells.pos_get(head_cur + i) <= ubatch.pos[i] && // causal mask + cells.seq_has(head_cur + i, seq_id) && // sequence mask + cells.seq_count(head_cur + i) == 1 && + is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask + ); + + if (!can_use) { found = false; - head += i + 1; + head_cur += i + 1; n_tested += i + 1; break; } @@ -466,14 +598,23 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return -1; } } - // store the old state of the cells in the recovery stack - recovery.states.push_back({head, cells.cp(head, n_tokens)}); + return head_cur; +} + +void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & ubatch) { + head = head_cur; + + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (!cells.is_empty(head + i)) { + cells.pos_chg(head + i, ubatch.pos[i]); + + continue; + } - for (uint32_t i = 0; i < n_tokens; ++i) { cells.pos_set(head + i, ubatch.pos[i]); for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { @@ -482,15 +623,8 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { } // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important + // as the cache gets filled, the benefit from this heuristic disappears n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); - -#ifdef FIND_SLOT_DEBUG - LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); -#endif - - return true; } bool llama_kv_cache_unified::get_can_shift() const { @@ -580,33 +714,6 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ return ggml_cpy(ctx, v_cur, v_view); } -void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) { - // no pruning is needed when the cache does not use SWA - GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache"); - - int n_attended = 0; - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.seq_has(i, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(i); - - if (p0 <= pmin && !is_masked_swa(p0, pmin)) { - n_attended++; - } - - if (is_masked_swa(p0, pmax)) { - cells.seq_rm(i, seq_id); - } - } - - if (n_attended < std::min(n_swa, pmin)) { - LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa); - } -} - void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; @@ -1362,12 +1469,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell batch.seq_id[i] = &dest_seq_id; } - if (!find_slot(batch)) { + const auto head_cur = find_slot(batch); + if (head_cur < 0) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - commit(); + fill_slot(head_cur, batch); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells @@ -1425,10 +1533,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); return false; } + if (cell_count > cells.size()) { LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); return false; } + if (this->v_trans != (bool) v_trans) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; @@ -1543,6 +1653,65 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // llama_kv_cache_unified_iswa // +class llama_kv_cache_unified_iswa_decode_state_t : public llama_memory_decode_state_i { +public: + llama_kv_cache_unified_iswa_decode_state_t(llama_memory_status status) : status(status) {} + + llama_kv_cache_unified_iswa_decode_state_t( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads_base(std::move(heads_base)), + heads_swa(std::move(heads_swa)), + ubatches(std::move(ubatches)) { + } + + ~llama_kv_cache_unified_iswa_decode_state_t() = default; + + llama_ubatch * next() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (i_next >= ubatches.size()) { + return nullptr; + } + + kv->get_kv_base()->fill_slot(heads_base[i_next], ubatches[i_next]); + kv->get_kv_swa ()->fill_slot(heads_swa [i_next], ubatches[i_next]); + + return &ubatches[i_next++]; + } + + std::vector & out_ids() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; + } + + llama_memory_status get_status() const override { + return status; + } + +private: + const llama_memory_status status; + + llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads_base; + std::vector heads_swa; + std::vector ubatches; +}; + llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( const llama_model & model, ggml_type type_k, @@ -1552,22 +1721,21 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool swa_full, uint32_t kv_size, uint32_t n_seq_max, - uint32_t n_batch, + uint32_t n_ubatch, uint32_t n_pad) : hparams(model.hparams) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); - // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); size_swa = size_base; - do_prune = false; } LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); @@ -1628,31 +1796,40 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -void llama_kv_cache_unified_iswa::restore() { - kv_base->restore(); - kv_swa ->restore(); -} +llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); -void llama_kv_cache_unified_iswa::commit() { - kv_base->commit(); - kv_swa ->commit(); + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); - // slide the attention window, forgetting/pruning old tokens that are outside the window - if (do_prune) { - for (const auto & [seq_id, entry] : pending.pos) { - kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax); - } + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - pending.clear(); + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = true; + bool res = false; - res = res & kv_base->update(lctx); - res = res & kv_swa ->update(lctx); + res = res | kv_base->update(lctx); + res = res | kv_swa ->update(lctx); return res; } @@ -1667,43 +1844,6 @@ void llama_kv_cache_unified_iswa::set_full() { kv_swa ->set_full(); } -llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) { - pending.clear(); - - if (do_prune) { - for (int i = 0; i < batch.n_tokens; ++i) { - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - const llama_seq_id seq_id = batch.seq_id[i][s]; - const llama_pos pos = batch.pos[i]; - - if (pending.pos.find(seq_id) == pending.pos.end()) { - pending.pos[seq_id].pmin = pos; - pending.pos[seq_id].pmax = pos; - } else { - pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos); - pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos); - } - } - } - } - - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { - bool res = true; - - res = res & kv_base->find_slot(batch); - res = res & kv_swa ->find_slot(batch); - - return res; -} - bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } @@ -1730,6 +1870,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { // llama_kv_cache_recurrent // +class llama_kv_cache_recurrent_decode_state_t : public llama_memory_decode_state_i { +public: + llama_kv_cache_recurrent_decode_state_t(llama_memory_status status) : status(status) {} + + llama_kv_cache_recurrent_decode_state_t( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + + ~llama_kv_cache_recurrent_decode_state_t() override = default; + + llama_ubatch * next() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (i_next >= ubatches.size()) { + return nullptr; + } + + kv->find_slot(ubatches[i_next]); + + return &ubatches[i_next++]; + } + + std::vector & out_ids() override { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; + } + + llama_memory_status get_status() const override { + return status; + } + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; +}; + llama_kv_cache_recurrent::llama_kv_cache_recurrent( const llama_model & model, ggml_type type_k, @@ -2071,20 +2257,69 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_recurrent::restore() { - if (pending.ranges.empty()) { - return; +llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); } - seq_rm(-1, -1, -1); + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); } -void llama_kv_cache_recurrent::commit() { - pending.ranges.clear(); +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + // TODO: here we have to verify that all ubatches can fit in the cells + // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells + // during the compute of each ubatch. to reproduce, uncomment the following loop and run: + // + // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 + // + // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed + // + GGML_UNUSED(ubatches); + //for (const auto & ubatch : ubatches) { + // if (!find_slot(ubatch)) { + // success = false; + // break; + // } + //} + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; } -bool llama_kv_cache_recurrent::update(llama_context & ctx) { - GGML_UNUSED(ctx); +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + // noop return false; } @@ -2098,23 +2333,7 @@ void llama_kv_cache_recurrent::set_full() { head = 0; } -llama_sbatch llama_kv_cache_recurrent::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, false, logits_all); -} - -llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); - } - - return sbatch.split_equal(n_ubatch); -} - -bool llama_kv_cache_recurrent::find_slot( - const llama_ubatch & ubatch) { +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -2332,18 +2551,6 @@ float llama_kv_cache_recurrent::s_mask(int i) const { return res; } -uint32_t llama_kv_cache_recurrent::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - size_t llama_kv_cache_recurrent::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -2558,11 +2765,11 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce } batch.n_seq_id[0] = 1; batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - commit(); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index ce6261e45a6e1..ac8fb7588b62c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -22,14 +22,18 @@ struct llama_context; struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; - // call if batch processing fails - restores the cache state - virtual void restore() = 0; - - // call after successful batch processing - clears any pending state - virtual void commit() = 0; + // split the input batch into a set of ubatches and verify that they can fit into the cache + // check the llama_memory_decode_state_i::get_status() for the result + virtual llama_memory_decode_state_ptr init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch + // return true if any operations were performed + // will reserve a new worst-case graph if needed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing @@ -39,23 +43,6 @@ struct llama_kv_cache : public llama_memory_i { // TODO: remove virtual void set_full() = 0; - // - // batch processing - // - - // ============================================================================================================= - // TODO: refactor and simplify this [TAG: KV_API] - - virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; - - // different KV caches require different batch splitting strategies - virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; - - // find an empty slot of size "n_tokens" in the cache - virtual bool find_slot(const llama_ubatch & batch) = 0; - - // ============================================================================================================= - // getters virtual bool get_can_shift() const = 0; @@ -69,25 +56,6 @@ struct llama_kv_cache : public llama_memory_i { virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; -// -// llama_kv_cache_guard -// - -struct llama_kv_cache_guard { - llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} - - ~llama_kv_cache_guard() { - kv->restore(); - } - - void commit() { - kv->commit(); - } - -private: - llama_kv_cache * kv; -}; - // // llama_kv_cache_unified // @@ -133,23 +101,18 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; + llama_memory_decode_state_ptr init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - bool update(llama_context & ctx) override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; void set_full() override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch) override; - bool get_can_shift() const override; // state write/load @@ -172,7 +135,16 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; - void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax); + // return empty vector on failure + std::vector prepare(const std::vector & ubatches); + + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into cells [head_cur, head_cur + ubatch.n_tokens) + // updates head = head_cur + void fill_slot(uint32_t head_cur, const llama_ubatch & ubatch); void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_k_shift (ggml_tensor * dst) const; @@ -220,24 +192,6 @@ class llama_kv_cache_unified : public llama_kv_cache { // model layer id -> KV cache layer id std::unordered_map map_layer_ids; - // recovery information used to restore the KV cells to their original state in case of a failure - // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation - // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] - struct { - void clear() { - states.clear(); - } - - struct state { - uint32_t i; - - llama_kv_cells_unified cells; - }; - - // stack with the partial states before each ubatch - std::vector states; - } recovery; - // defrag struct { std::vector ids; @@ -285,7 +239,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // utilizes two instances of llama_kv_cache_unified // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -// upon successful commit, the SWA cache removes old tokens outside the n_swa window +// upon successful processing of the batch, the SWA cache removes old tokens outside the n_swa window class llama_kv_cache_unified_iswa : public llama_kv_cache { public: @@ -298,7 +252,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { bool swa_full, uint32_t kv_size, uint32_t n_seq_max, - uint32_t n_batch, + uint32_t n_ubatch, uint32_t n_pad); ~llama_kv_cache_unified_iswa() = default; @@ -322,20 +276,18 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; + llama_memory_decode_state_ptr init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - bool update(llama_context & ctx) override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; void set_full() override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - bool find_slot(const llama_ubatch & batch) override; - bool get_can_shift() const override; // state write/load @@ -353,22 +305,6 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { private: const llama_hparams & hparams; - bool do_prune = true; - - struct { - struct entry { - llama_pos pmin; - llama_pos pmax; - }; - - void clear() { - pos.clear(); - } - - // used to perform SWA pruning of old tokens - std::unordered_map pos; - } pending; - std::unique_ptr kv_base; std::unique_ptr kv_swa; }; @@ -379,26 +315,6 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: - struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - llama_kv_cache_recurrent( const llama_model & model, ggml_type type_k, @@ -428,19 +344,22 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; + llama_memory_decode_state_ptr init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - bool update(llama_context & ctx) override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; void set_full() override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + bool prepare(const std::vector & ubatches); - bool find_slot(const llama_ubatch & batch) override; + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); bool get_can_shift() const override; @@ -460,6 +379,27 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // computed before each graph build uint32_t n = 0; + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + std::vector cells; std::vector k_l; // per layer @@ -469,26 +409,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache { //const llama_model & model; const llama_hparams & hparams; - // commit/restore cache - // TODO: rework for recurrent cache - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; - - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; - const uint32_t n_seq_max = 1; std::vector ctxs; std::vector bufs; - // find how many cells are currently in use - uint32_t cell_max() const; - size_t total_size() const; size_t size_k_bytes() const; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index dbbd03fcba281..5c6b8d0dcdc78 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -196,6 +196,13 @@ class llama_kv_cells_unified { return false; } + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + bool seq_has(uint32_t i, llama_seq_id seq_id) const { assert(i < pos.size()); assert(seq_id >= 0); @@ -274,6 +281,16 @@ class llama_kv_cells_unified { used.insert(i); } + // change the position of a non-empty cell + // does not modify "has_shift" + // note: call only if the cell is not empty + void pos_chg(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] != -1); + + pos[i] = p; + } + // pos[i] = pos[i] + d // sets "has_shift" to true // note: call only if the cell is not empty diff --git a/src/llama-memory.h b/src/llama-memory.h index a2d250434affa..44a45da9f5891 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -2,6 +2,11 @@ #include "llama.h" +#include +#include + +struct llama_ubatch; + struct llama_memory_params { // kv cache ggml_type type_k; @@ -30,3 +35,25 @@ class llama_memory_i { virtual bool get_can_edit() const = 0; }; + +enum llama_memory_status { + LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_FAILED_PREPARE, + LLAMA_MEMORY_STATUS_FAILED_COMPUTE, +}; + +class llama_memory_decode_state_i { +public: + virtual ~llama_memory_decode_state_i() = default; + + // consume the next ubatch from the decode state + // return nullptr if we are done + virtual llama_ubatch * next() = 0; + + // TODO: this might get reworked in the future when refactoring llama_batch + virtual std::vector & out_ids() = 0; + + virtual llama_memory_status get_status() const = 0; +}; + +using llama_memory_decode_state_ptr = std::unique_ptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3f1f6c9bf3b06..2be5b8f9f1c48 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.swa_full, cparams.n_ctx, cparams.n_seq_max, - cparams.n_batch, + cparams.n_ubatch, padding); } else { GGML_ASSERT(!hparams.is_swa_any()); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 5d03dc3dc790a..96682cc07cb64 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2016,11 +2016,6 @@ struct server_context { params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); } - - if (!params_base.speculative.model.path.empty()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by this context"); - return false; - } } return true; @@ -3213,8 +3208,10 @@ struct server_context { slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()" } - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { - if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) { + if (slot.n_past > 0 && slot.n_past + 32 < (int) slot.cache_tokens.size()) { + const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id); + if (pos_min > 0) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); slot.n_past = 0; From 9fc50dcdcbe3ab56425700ee15d3aad121b8ddec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 17:12:58 +0300 Subject: [PATCH 02/14] kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci --- src/llama-kv-cache.cpp | 4 ++-- src/llama-kv-cache.h | 2 +- src/llama-model.cpp | 2 +- tools/server/server.cpp | 10 ++++++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 1ef3c5b2017e5..862e3f8ee35c5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1721,14 +1721,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool swa_full, uint32_t kv_size, uint32_t n_seq_max, - uint32_t n_ubatch, + uint32_t n_batch, uint32_t n_pad) : hparams(model.hparams) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index ac8fb7588b62c..a966e2c3a56a2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -252,7 +252,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { bool swa_full, uint32_t kv_size, uint32_t n_seq_max, - uint32_t n_ubatch, + uint32_t n_batch, uint32_t n_pad); ~llama_kv_cache_unified_iswa() = default; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2be5b8f9f1c48..3f1f6c9bf3b06 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.swa_full, cparams.n_ctx, cparams.n_seq_max, - cparams.n_ubatch, + cparams.n_batch, padding); } else { GGML_ASSERT(!hparams.is_swa_any()); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 96682cc07cb64..91b73afa7c794 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2016,6 +2016,11 @@ struct server_context { params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by this context"); + return false; + } } return true; @@ -3208,7 +3213,7 @@ struct server_context { slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()" } - if (slot.n_past > 0 && slot.n_past + 32 < (int) slot.cache_tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id); if (pos_min > 0) { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); @@ -3422,10 +3427,11 @@ struct server_context { // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + i -= n_batch; + continue; // continue loop of n_batch } From c2c35917d0a6fce406820fd2876196658a247cf2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 17:17:38 +0300 Subject: [PATCH 03/14] kv-cache : some comments ggml-ci --- src/llama-context.h | 2 +- src/llama-kv-cache.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-context.h b/src/llama-context.h index db9d388d0d42f..2de7368293b12 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -185,7 +185,7 @@ struct llama_context { ggml_cgraph * gf, bool batched); - // reserve a graph + // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); private: diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index a966e2c3a56a2..2ba61d62427d7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -135,6 +135,7 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + // find places for the provided ubatches in the cache, returns the head locations // return empty vector on failure std::vector prepare(const std::vector & ubatches); From 885678201d3a211938dca309d39102d86ce52ff2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 17:41:05 +0300 Subject: [PATCH 04/14] context : fix graph reserve for multiple sequences ggml-ci --- src/llama-context.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 534448758e730..ea2d718f49f93 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1219,6 +1219,13 @@ ggml_cgraph * llama_context::graph_init() { ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + if (n_tokens % n_seqs != 0) { + n_tokens = (n_tokens / n_seqs) * n_seqs; + n_outputs = std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + // store the n_outputs as it is, and restore it afterwards // TODO: not sure if needed, might simplify in the future by removing this const auto save_n_outputs = this->n_outputs; From bffb9d4a15db370db55f5f009bf4b8738b359114 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 17:55:07 +0300 Subject: [PATCH 05/14] kv-cache : fix typo [no ci] --- src/llama-kv-cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 862e3f8ee35c5..0562fe189d6f6 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -574,7 +574,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { // can we use this cell? either: // - the cell is empty - // - the cell is occupied only by the same sequence, and the sequence is not masked + // - the cell is occupied only by the same sequence, and the pos is masked const bool can_use = cells.is_empty(head_cur + i) || ( @@ -2285,7 +2285,7 @@ llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & } bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { - // simply remember the full state + // simply remember the full state because it is very small for this type of cache // TODO: optimize auto org_cells = cells; auto org_used = used; From 32cc9eab24e58f90f3b22721db72a920f486585e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 May 2025 21:58:57 +0300 Subject: [PATCH 06/14] kv-cache : fix find_slot() logic for free slots ggml-ci --- src/llama-kv-cache.cpp | 11 +++++++---- src/llama-kv-cache.h | 2 -- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0562fe189d6f6..9423ed958780a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -570,6 +570,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { + const llama_pos pos = ubatch.pos[i]; const llama_seq_id seq_id = ubatch.seq_id[i][0]; // can we use this cell? either: @@ -578,10 +579,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const bool can_use = cells.is_empty(head_cur + i) || ( - cells.pos_get(head_cur + i) <= ubatch.pos[i] && // causal mask - cells.seq_has(head_cur + i, seq_id) && // sequence mask - cells.seq_count(head_cur + i) == 1 && - is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask + cells.seq_has (head_cur + i, seq_id) && // sequence mask + cells.seq_count(head_cur + i) == 1 && + ( + cells.pos_get (head_cur + i) >= pos || // causal mask + is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask + ) ); if (!can_use) { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 2ba61d62427d7..f1ba7cba390e2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -33,7 +33,6 @@ struct llama_kv_cache : public llama_memory_i { // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch // return true if any operations were performed - // will reserve a new worst-case graph if needed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing @@ -240,7 +239,6 @@ class llama_kv_cache_unified : public llama_kv_cache { // utilizes two instances of llama_kv_cache_unified // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -// upon successful processing of the batch, the SWA cache removes old tokens outside the n_swa window class llama_kv_cache_unified_iswa : public llama_kv_cache { public: From f97de9b784f043b319feecff62284219555734e1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 May 2025 11:38:49 +0300 Subject: [PATCH 07/14] llama : add TODO for deprecating the defrag API in the future --- include/llama.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/llama.h b/include/llama.h index f0a1061dec960..29677d74207a3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -692,13 +692,14 @@ extern "C" { // This will be applied: // - lazily on next llama_decode() // - explicitly with llama_kv_self_update() + // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG] LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); // Check if the context supports KV cache shifting LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - // TODO: deprecate and always update the cache lazily + // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG] LLAMA_API void llama_kv_self_update(struct llama_context * ctx); // From 7764d91497d853f2c6c255e3ae0daa39e94ab2df Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 27 May 2025 16:10:15 +0300 Subject: [PATCH 08/14] kv-cache : improve find_slot() using min/max seq pos info ggml-ci --- src/llama-batch.cpp | 20 ---------------- src/llama-batch.h | 6 ----- src/llama-context.cpp | 2 +- src/llama-kv-cache.cpp | 52 ++++++++++++++++++++++++++++++------------ src/llama-kv-cells.h | 41 ++++++++++++++++++++------------- 5 files changed, 63 insertions(+), 58 deletions(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index ac6bbef57d881..6a19a243118d3 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -4,21 +4,6 @@ #include #include -void llama_ubatch::update() { - if (equal_seqs) { - // TODO: for now don't compute min/max for recurrent batches since we don't need this. - // the batches will be refactored anyway, so we'll fix this later - return; - } - - for (uint32_t i = 0; i < n_tokens; ++i) { - const llama_seq_id s = seq_id[i][0]; - - seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]); - seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]); - } -} - llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { // clear empty sequences // the previous ubatch is assumed to be gone, @@ -47,8 +32,6 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*seq_pos_min =*/ {-1}, - /*seq_pos_max =*/ {-1}, /*token =*/ !has_embd ? udata.token.data() : nullptr, /*embd =*/ has_embd ? udata.embd.data() : nullptr, /*pos =*/ udata.pos.data(), @@ -172,7 +155,6 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits add_seq_to_ubatch(ubatch, s, length); } - ubatch.update(); return ubatch; } @@ -200,7 +182,6 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { if (length + n_tokens_in_ubatch > n_ubatch) { break; } } } - ubatch.update(); return ubatch; } @@ -213,7 +194,6 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits add_seq_to_ubatch(ubatch, s, length); } - ubatch.update(); return ubatch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index bd65ec6a935b4..b8260b94fd2d0 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -1,7 +1,6 @@ #pragma once #include "llama.h" -#include "llama-cparams.h" #include #include @@ -9,8 +8,6 @@ // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { - void update(); - bool equal_seqs; // TODO: whole_seqs for embeddings? @@ -18,9 +15,6 @@ struct llama_ubatch { uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; - llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence - llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence - llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ea2d718f49f93..e3409158e3ab0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1233,7 +1233,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u this->n_outputs = n_outputs; llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; auto * gf = graph_init(); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 9423ed958780a..f316070e92495 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -548,7 +548,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { if (cells.is_empty(i)) { ss += '.'; } else { - ss += 'x'; + ss += std::to_string(cells.seq_get(i)); } if (i%256 == 255) { ss += '\n'; @@ -557,6 +557,10 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { } LLAMA_LOG_WARN("\n%s\n", ss.c_str()); } + + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n", n_swa, cells.seq_pos_min(0), cells.seq_pos_max(0)); + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n", n_swa, cells.seq_pos_min(1), cells.seq_pos_max(1)); + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n", n_swa, cells.seq_pos_min(2), cells.seq_pos_max(2)); #endif uint32_t n_tested = 0; @@ -568,6 +572,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { continue; } + // keep track of what the minimum sequence positions would be if we accept the ubatch + llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos_min[s] = cells.seq_pos_min(s); + } + bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { const llama_pos pos = ubatch.pos[i]; @@ -575,17 +585,31 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { // can we use this cell? either: // - the cell is empty - // - the cell is occupied only by the same sequence, and the pos is masked - const bool can_use = - cells.is_empty(head_cur + i) || - ( - cells.seq_has (head_cur + i, seq_id) && // sequence mask - cells.seq_count(head_cur + i) == 1 && - ( - cells.pos_get (head_cur + i) >= pos || // causal mask - is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask - ) - ); + // - the cell is occupied only by one sequence: + // - mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // causal mask + if (cells.seq_has(head_cur + i, seq_id)) { + can_use = pos_cell >= pos; + } + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + if (pos_cell == seq_pos_min[seq_id_cell] && + is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + seq_pos_min[seq_id_cell]++; + can_use = true; + } + } + } if (!can_use) { found = false; @@ -613,9 +637,7 @@ void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & u for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { if (!cells.is_empty(head + i)) { - cells.pos_chg(head + i, ubatch.pos[i]); - - continue; + cells.rm(head + i); } cells.pos_set(head + i, ubatch.pos[i]); diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 5c6b8d0dcdc78..d5331957105cd 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -68,12 +68,6 @@ class llama_kv_cells_unified { // the index of the last cell that is used + 1 // return 0 if no cells are used uint32_t used_max_p1() const { -#if 0 - if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); - if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); - if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); -#endif - return used.empty() ? 0 : *used.rbegin() + 1; } @@ -144,6 +138,18 @@ class llama_kv_cells_unified { } } + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] = -1; + seq[i].reset(); + + used.erase(i); + } + // note: call only if the cell has seq_id // return true if the cell becomes empty bool seq_rm(uint32_t i, llama_seq_id seq_id) { @@ -220,6 +226,18 @@ class llama_kv_cells_unified { seq_pos[seq_id].insert(pos[i]); } + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + // the minimum position of sequence seq_id currently present in any of the cells // return -1 if the sequence is not present llama_pos seq_pos_min(llama_seq_id seq_id) const { @@ -275,22 +293,13 @@ class llama_kv_cells_unified { void pos_set(uint32_t i, llama_pos p) { assert(i < pos.size()); assert(pos[i] == -1); + assert(seq[i].none()); pos[i] = p; used.insert(i); } - // change the position of a non-empty cell - // does not modify "has_shift" - // note: call only if the cell is not empty - void pos_chg(uint32_t i, llama_pos p) { - assert(i < pos.size()); - assert(pos[i] != -1); - - pos[i] = p; - } - // pos[i] = pos[i] + d // sets "has_shift" to true // note: call only if the cell is not empty From 780bba94d84995b6607830a07f72cf742f16a032 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 28 May 2025 13:50:14 +0300 Subject: [PATCH 09/14] llama : handle aborts and compute errors ggml-ci --- include/llama.h | 2 + src/llama-context.cpp | 111 +++++++++++++++++++++++++++++------------- src/llama-context.h | 12 +++-- 3 files changed, 89 insertions(+), 36 deletions(-) diff --git a/include/llama.h b/include/llama.h index 29677d74207a3..adc4c69288a3d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -677,12 +677,14 @@ extern "C" { // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, llama_seq_id seq_id); // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e3409158e3ab0..dd16e3d63e97e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,9 +6,10 @@ #include "llama-model.h" #include "llama-kv-cache.h" +#include #include +#include #include -#include // // llama_context @@ -632,6 +633,49 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } +llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) { + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + if (ret) { + *ret = GGML_STATUS_FAILED; + } + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + if (ret) { + *ret = GGML_STATUS_FAILED; + } + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + if (ret) { + *ret = GGML_STATUS_ALLOC_FAILED; + } + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + if (ret) { + *ret = status; + } + return nullptr; + } + + return res; +} + int llama_context::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -703,26 +747,18 @@ int llama_context::encode(llama_batch & inp_batch) { // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 cparams.causal_attn = false; - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(&ubatch); + ggml_status status; + auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status); cparams.causal_attn = causal_attn_org; - const auto compute_status = graph_compute(gf, n_tokens > 1); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } } auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); @@ -942,25 +978,34 @@ int llama_context::decode(llama_batch & inp_batch) { ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + ggml_status status; + auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status); + + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits::max() }; + + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } - ggml_backend_sched_alloc_graph(sched.get(), gf); + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } - res->set_inputs(&ubatch); + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + llama_kv_self_seq_rm(this, s, pos_min[s], -1); + } - const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); - if (compute_status != GGML_STATUS_SUCCESS) { - switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } diff --git a/src/llama-context.h b/src/llama-context.h index 2de7368293b12..6c11b584679ac 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -89,6 +89,14 @@ struct llama_context { int32_t il_start, int32_t il_end); + // process a single ubatch with a specific graph type + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process( + const llama_ubatch & ubatch, + llm_graph_type gtype, + ggml_status * ret); + int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); @@ -181,9 +189,7 @@ struct llama_context { ggml_cgraph * graph_init(); // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); From dbcfa5f1d75170795458d1aa70de286d6f6df94c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 28 May 2025 19:20:12 +0300 Subject: [PATCH 10/14] memory : extract state into llama_memory_state ggml-ci --- src/llama-context.cpp | 97 ++++---- src/llama-context.h | 24 +- src/llama-graph.cpp | 92 ++++---- src/llama-graph.h | 49 +++-- src/llama-kv-cache.cpp | 490 ++++++++++++++++++++++++----------------- src/llama-kv-cache.h | 276 ++++++++++++++++++++--- src/llama-memory.h | 29 ++- src/llama-model.cpp | 28 +-- 8 files changed, 712 insertions(+), 373 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index dd16e3d63e97e..808fe5991088c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -274,13 +274,16 @@ llama_context::llama_context( // simulate full KV cache llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->set_full(); + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } cross.v_embd.clear(); // reserve pp graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -291,7 +294,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1); + auto * gf = graph_reserve(1, 1, 1, kv_state.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -302,7 +305,7 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -430,12 +433,15 @@ void llama_context::kv_self_update() { if (kv_self->update(*this)) { // if the KV cache did any computation, we have to reserve a new worst-case graph - kv_self->set_full(); + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); } @@ -633,22 +639,24 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) { +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + auto * gf = graph_init(); if (!gf) { LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); - if (ret) { - *ret = GGML_STATUS_FAILED; - } + ret = GGML_STATUS_FAILED; return nullptr; } - auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype); + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); - if (ret) { - *ret = GGML_STATUS_FAILED; - } + ret = GGML_STATUS_FAILED; return nullptr; } @@ -656,9 +664,7 @@ llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_gra if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - if (ret) { - *ret = GGML_STATUS_ALLOC_FAILED; - } + ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } @@ -667,12 +673,12 @@ llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_gra const auto status = graph_compute(gf, ubatch.n_tokens > 1); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); - if (ret) { - *ret = status; - } + ret = status; return nullptr; } + ret = GGML_STATUS_SUCCESS; + return res; } @@ -748,7 +754,7 @@ int llama_context::encode(llama_batch & inp_batch) { cparams.causal_attn = false; ggml_status status; - auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status); + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; @@ -927,12 +933,12 @@ int llama_context::decode(llama_batch & inp_batch) { // handle any pending defrags/shifts kv_self_update(); - auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); - if (!decode_state) { + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!kv_state) { return -2; } - switch (decode_state->get_status()) { + switch (kv_state->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; @@ -955,8 +961,8 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; - while (const auto * ubatch_ptr = decode_state->next()) { - const auto & ubatch = *ubatch_ptr; + do { + const auto & ubatch = kv_state->get_ubatch(); // count the outputs in this u_batch { @@ -979,7 +985,7 @@ int llama_context::decode(llama_batch & inp_batch) { ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_status status; - auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status); + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1092,7 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } + } while (kv_state->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1101,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = decode_state->out_ids(); + auto & out_ids = kv_state->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1261,7 +1267,7 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); if (n_tokens % n_seqs != 0) { @@ -1281,7 +1287,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); this->n_outputs = save_n_outputs; @@ -1302,10 +1308,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u } llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1317,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ memory.get(), + /*.mstate =*/ mstate, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2020,8 +2027,8 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); - if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); break; } @@ -2033,13 +2040,18 @@ void llama_context::opt_epoch_iter( }; uint32_t pos_batch = 0; - while (const auto * ubatch_ptr = decode_state->next()) { - const auto & ubatch = *ubatch_ptr; + do { + const auto & ubatch = kv_state->get_ubatch(); n_outputs = ubatch.n_tokens; + if (!kv_state->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; + } + auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); struct ggml_context * ctx_compute_opt; { @@ -2054,6 +2066,7 @@ void llama_context::opt_epoch_iter( } ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); @@ -2073,7 +2086,7 @@ void llama_context::opt_epoch_iter( ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; - } + } while (kv_state->next()); } } diff --git a/src/llama-context.h b/src/llama-context.h index 6c11b584679ac..5b79bafa75db7 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -18,6 +18,9 @@ struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; +class llama_memory_i; +class llama_memory_state_i; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -90,12 +93,14 @@ struct llama_context { int32_t il_end); // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory // ret contains the status of the graph computation // returns nullptr only if ret != GGML_STATUS_SUCCESS - llm_graph_result_ptr process( - const llama_ubatch & ubatch, - llm_graph_type gtype, - ggml_status * ret); + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); @@ -192,14 +197,15 @@ struct llama_context { ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype); + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); llm_graph_cb graph_get_cb() const; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7c383e2eb3f27..b30f6fb4f4145 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -83,7 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_self->set_input_pos_bucket(pos_bucket, ubatch); + kv_state->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -234,7 +234,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -242,7 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_copy(i); + data[i] = kv_state->s_copy(i); } } } @@ -250,7 +250,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); @@ -258,7 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_mask(i); + data[i] = kv_state->s_mask(i); } } } @@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -448,7 +448,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - memory (params.memory), + mstate (params.mstate), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { @@ -954,11 +954,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; @@ -971,11 +971,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_mask; @@ -1025,11 +1025,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, kv_self); + auto inp = std::make_unique(hparams, kv_state); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->pos_bucket; @@ -1231,14 +1231,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1268,19 +1268,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); // store to KV cache { - ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_self->get_k(ctx0, il); - ggml_tensor * v = kv_self->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1301,12 +1301,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { - const auto n_kv = kv_self->get_kv_base()->get_n(); + const auto n_kv = kv_state->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_self->get_kv_swa()->get_n(); + const auto n_kv = kv_state->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1348,23 +1348,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const bool is_swa = hparams.is_swa(il); + const auto * kv_state_iswa = static_cast(mstate); - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const bool is_swa = hparams.is_swa(il); - const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base(); + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv->get_k(ctx0, il); - ggml_tensor * v = kv->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1446,12 +1446,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_self->n; - const auto kv_head = kv_self->head; + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size()); // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv @@ -1478,13 +1478,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_self->k_l[il]; + ggml_tensor * token_shift_all = kv_state->get_k_l(il); ggml_tensor * token_shift = build_copy_mask_state( gf, token_shift_all, state_copy, state_mask, @@ -1499,19 +1499,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) ); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 2b85bb25befba..d1c5dd1bf036f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,10 +17,11 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_i; -class llama_kv_cache_unified; -class llama_kv_cache_unified_iswa; -class llama_kv_cache_recurrent; +class llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -141,7 +142,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -188,26 +189,26 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -247,10 +248,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified * kv_self) : + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified() = default; @@ -264,7 +265,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -272,10 +273,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa * kv_self) : + const llama_kv_cache_unified_iswa_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -292,7 +293,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa * kv_self; + const llama_kv_cache_unified_iswa_state * kv_state; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -383,10 +384,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; int32_t n_outputs; @@ -435,10 +436,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index f316070e92495..0906287c26226 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -17,66 +17,6 @@ // llama_kv_cache_unified // -class llama_kv_cache_unified_decode_state_t : public llama_memory_decode_state_i { -public: - llama_kv_cache_unified_decode_state_t(llama_memory_status status) : status(status) {} - - llama_kv_cache_unified_decode_state_t( - llama_memory_status status, - llama_kv_cache_unified * kv, - llama_sbatch sbatch, - std::vector heads, - std::vector ubatches) - : status(status), - kv(kv), - sbatch(std::move(sbatch)), - heads(std::move(heads)), - ubatches(std::move(ubatches)) { - } - - ~llama_kv_cache_unified_decode_state_t() = default; - - llama_ubatch * next() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - if (i_next >= ubatches.size()) { - return nullptr; - } - - kv->fill_slot(heads[i_next], ubatches[i_next]); - - return &ubatches[i_next++]; - } - - std::vector & out_ids() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return sbatch.out_ids; - } - - llama_memory_status get_status() const override { - return status; - } - -private: - const llama_memory_status status; - - llama_kv_cache_unified * kv; - - llama_sbatch sbatch; - - // the index of the next ubatch to process - size_t i_next = 0; - - std::vector heads; - std::vector ubatches; -}; - -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} - llama_kv_cache_unified::llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -348,7 +288,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -llama_memory_decode_state_ptr llama_kv_cache_unified::init( +llama_memory_state_ptr llama_kv_cache_unified::init_batch( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, @@ -364,13 +304,17 @@ llama_memory_decode_state_ptr llama_kv_cache_unified::init( auto heads = prepare(ubatches); if (heads.empty()) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(heads), std::move(ubatches)); } +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { std::vector res; @@ -401,7 +345,7 @@ std::vector llama_kv_cache_unified::prepare(const std::vector= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -507,17 +453,6 @@ void llama_kv_cache_unified::defrag_sched(float thold) { } } -void llama_kv_cache_unified::set_full() { - n = cells.size(); - - // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not - // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. - // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so - // setting it to 0 is the simplest way to achieve that - // ref: https://github.com/ggml-org/llama.cpp/issues/13359 - head = 0; -} - int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; @@ -538,7 +473,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { //#define FIND_SLOT_DEBUG 1 #if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, cells.get_used(), head, n_swa); + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); // for debugging { @@ -632,51 +567,48 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { return head_cur; } -void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & ubatch) { - head = head_cur; - +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - if (!cells.is_empty(head + i)) { - cells.rm(head + i); + if (!cells.is_empty(head_cur + i)) { + cells.rm(head_cur + i); } - cells.pos_set(head + i, ubatch.pos[i]); + cells.pos_set(head_cur + i, ubatch.pos[i]); for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells.seq_add(head + i, ubatch.seq_id[i][j]); + cells.seq_add(head_cur + i, ubatch.seq_id[i][j]); } } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // as the cache gets filled, the benefit from this heuristic disappears - n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; } bool llama_kv_cache_unified::get_can_shift() const { return true; } -uint32_t llama_kv_cache_unified::get_n() const { - return n; -} - uint32_t llama_kv_cache_unified::get_size() const { return cells.size(); } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ggml_row_size(k->type, hparams.n_embd_head_k), ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), 0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -684,7 +616,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] 0); @@ -692,13 +624,13 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons // note: v->nb[1] > v->nb[2] return ggml_view_3d(ctx, v, - n, hparams.n_head_kv(il), hparams.n_embd_head_v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] ggml_row_size(v->type, v->ne[1]), // v->nb[2] 0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; @@ -707,12 +639,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head); + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); return ggml_cpy(ctx, k_cur, k_view); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -726,12 +658,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ if (!v_trans) { v_view = ggml_view_1d(ctx, v, n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head); + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); } else { // note: the V cache is transposed when not using flash attention v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), (v->ne[1])*ggml_element_size(v), - ( head)*ggml_element_size(v)); + (head_cur)*ggml_element_size(v)); v_cur = ggml_transpose(ctx, v_cur); } @@ -747,7 +679,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = n; + const auto n_kv = dst->ne[0]; // Use only the previous KV cells of the correct sequence for each token of the ubatch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -768,7 +700,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub for (int j = 0; j < n_seq_tokens; ++j) { const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; - for (int i = 0; i < n_kv; ++i) { + for (uint32_t i = 0; i < n_kv; ++i) { float f = 0.0f; bool masked = false; @@ -804,7 +736,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // mask padded tokens if (data) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { + for (uint32_t j = 0; j < n_kv; ++j) { data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } @@ -830,7 +762,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama int32_t * data = (int32_t *) dst->data; - const int64_t n_kv = n; + const int32_t n_kv = dst->ne[0]; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { @@ -1500,15 +1432,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return false; } - fill_slot(head_cur, batch); + apply_ubatch(head_cur, batch); - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head) == batch.pos[0]); - GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); } else { // whole KV cache restore @@ -1675,67 +1610,110 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell } // -// llama_kv_cache_unified_iswa +// llama_kv_cache_unified_state // -class llama_kv_cache_unified_iswa_decode_state_t : public llama_memory_decode_state_i { -public: - llama_kv_cache_unified_iswa_decode_state_t(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv) : status(status), kv(kv) { + n_kv = kv->get_size(); + head = 0; + } - llama_kv_cache_unified_iswa_decode_state_t( +llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_memory_status status, - llama_kv_cache_unified_iswa * kv, + llama_kv_cache_unified * kv, llama_sbatch sbatch, - std::vector heads_base, - std::vector heads_swa, + std::vector heads, std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), - heads_base(std::move(heads_base)), - heads_swa(std::move(heads_swa)), + heads(std::move(heads)), ubatches(std::move(ubatches)) { } - ~llama_kv_cache_unified_iswa_decode_state_t() = default; +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; - llama_ubatch * next() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - if (i_next >= ubatches.size()) { - return nullptr; - } + if (++i_next >= ubatches.size()) { + return false; + } - kv->get_kv_base()->fill_slot(heads_base[i_next], ubatches[i_next]); - kv->get_kv_swa ()->fill_slot(heads_swa [i_next], ubatches[i_next]); + return true; +} - return &ubatches[i_next++]; - } +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - std::vector & out_ids() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + kv->apply_ubatch(heads[i_next], ubatches[i_next]); - return sbatch.out_ids; - } + n_kv = kv->get_n_kv(); + head = heads[i_next]; - llama_memory_status get_status() const override { - return status; - } + return true; +} -private: - const llama_memory_status status; +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - llama_kv_cache_unified_iswa * kv; + return sbatch.out_ids; +} - llama_sbatch sbatch; +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} - // the index of the next ubatch to process - size_t i_next = 0; +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - std::vector heads_base; - std::vector heads_swa; - std::vector ubatches; -}; + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + +// +// llama_kv_cache_unified_iswa +// llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( const llama_model & model, @@ -1821,7 +1799,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); @@ -1836,20 +1814,24 @@ llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batc auto heads_base = kv_base->prepare(ubatches); if (heads_base.empty()) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } auto heads_swa = kv_swa->prepare(ubatches); if (heads_swa.empty()) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } assert(heads_base.size() == heads_swa.size()); - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { bool res = false; @@ -1864,11 +1846,6 @@ void llama_kv_cache_unified_iswa::defrag_sched(float thold) { kv_swa ->defrag_sched(thold); } -void llama_kv_cache_unified_iswa::set_full() { - kv_base->set_full(); - kv_swa ->set_full(); -} - bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } @@ -1883,63 +1860,98 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id kv_swa ->state_read(io, seq_id); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const { +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { return kv_base.get(); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { return kv_swa.get(); } // -// llama_kv_cache_recurrent +// llama_kv_cache_unified_iswa_state // -class llama_kv_cache_recurrent_decode_state_t : public llama_memory_decode_state_i { -public: - llama_kv_cache_recurrent_decode_state_t(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} - llama_kv_cache_recurrent_decode_state_t( - llama_memory_status status, - llama_kv_cache_recurrent * kv, - llama_sbatch sbatch, - std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv) : status(status) { + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); +} - ~llama_kv_cache_recurrent_decode_state_t() override = default; +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + } - llama_ubatch * next() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; - if (i_next >= ubatches.size()) { - return nullptr; - } +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - kv->find_slot(ubatches[i_next]); + state_base->next(); + state_swa ->next(); - return &ubatches[i_next++]; + if (++i_next >= ubatches.size()) { + return false; } - std::vector & out_ids() override { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return sbatch.out_ids; - } + bool res = true; - llama_memory_status get_status() const override { - return status; - } + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} -private: - const llama_memory_status status; +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - llama_kv_cache_recurrent * kv; + return sbatch.out_ids; +} - llama_sbatch sbatch; +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} - size_t i_next = 0; +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} - std::vector ubatches; -}; +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_base.get(); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_swa.get(); +} + +// +// llama_kv_cache_recurrent +// llama_kv_cache_recurrent::llama_kv_cache_recurrent( const llama_model & model, @@ -2282,7 +2294,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); @@ -2303,10 +2315,14 @@ llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & } if (!prepare(ubatches)) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); } bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { @@ -2353,11 +2369,6 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) { // noop } -void llama_kv_cache_recurrent::set_full() { - n = size; - head = 0; -} - bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -2977,3 +2988,84 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 0 : kv->head; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->s_copy(i); +} + +float llama_kv_cache_recurrent_state::s_mask(int i) const { + return kv->s_mask(i); +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f1ba7cba390e2..d2439e13603a0 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,6 +2,7 @@ #include "llama.h" #include "llama-io.h" +#include "llama-batch.h" #include "llama-graph.h" #include "llama-memory.h" #include "llama-kv-cells.h" @@ -14,8 +15,6 @@ struct llama_cparams; struct llama_hparams; -struct llama_ubatch; -struct llama_sbatch; struct llama_model; struct llama_context; @@ -23,25 +22,28 @@ struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache - // check the llama_memory_decode_state_i::get_status() for the result - virtual llama_memory_decode_state_ptr init( + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) = 0; + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; + // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch // return true if any operations were performed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing + // TODO: change to + // llama_memory_state_ptr init_defrag(float thold) = 0; + // virtual void defrag_sched(float thold) = 0; - // simulate full cache, used for allocating worst-case compute buffers - // TODO: remove - virtual void set_full() = 0; - // getters virtual bool get_can_shift() const = 0; @@ -100,18 +102,18 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache // - llama_memory_decode_state_ptr init( + llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) override; + llama_memory_state_ptr init_full() override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; - void set_full() override; - bool get_can_shift() const override; // state write/load @@ -123,16 +125,25 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache_unified specific API // - uint32_t get_n() const; uint32_t get_size() const; + // + // graph_build API + // + + uint32_t get_n_kv() const; + // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - // store k_cur and v_cur in the cache based on the current head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // // find places for the provided ubatches in the cache, returns the head locations // return empty vector on failure @@ -142,9 +153,12 @@ class llama_kv_cache_unified : public llama_kv_cache { // return -1 on failure to find a contiguous slot of kv cells int32_t find_slot(const llama_ubatch & ubatch) const; - // emplace the ubatch context into cells [head_cur, head_cur + ubatch.n_tokens) - // updates head = head_cur - void fill_slot(uint32_t head_cur, const llama_ubatch & ubatch); + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_k_shift (ggml_tensor * dst) const; @@ -166,11 +180,9 @@ class llama_kv_cache_unified : public llama_kv_cache { bool do_defrag = false; bool v_trans = true; // the value tensor is transposed - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - - // computed before each graph build - // TODO: cells should start to maintain this value dynamically based on the edits - uint32_t n = 0; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + uint32_t head = 0; const uint32_t n_seq_max = 1; @@ -233,6 +245,82 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv); + + // used to create a state from a batch + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads; + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; + // // llama_kv_cache_unified_iswa // @@ -275,18 +363,18 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { // llama_kv_cache // - llama_memory_decode_state_ptr init( + llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) override; + llama_memory_state_ptr init_full() override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; - void set_full() override; - bool get_can_shift() const override; // state write/load @@ -298,8 +386,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { // llama_kv_cache_unified_iswa specific API // - llama_kv_cache_unified * get_kv_base() const; - llama_kv_cache_unified * get_kv_swa () const; + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; private: const llama_hparams & hparams; @@ -308,10 +396,68 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { std::unique_ptr kv_swa; }; +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + const llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + std::unique_ptr state_base; + std::unique_ptr state_swa; +}; + // // llama_kv_cache_recurrent // +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it class llama_kv_cache_recurrent : public llama_kv_cache { public: llama_kv_cache_recurrent( @@ -343,18 +489,18 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // llama_kv_cache // - llama_memory_decode_state_ptr init( + llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) override; + llama_memory_state_ptr init_full() override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; - void set_full() override; - bool prepare(const std::vector & ubatches); // find a contiguous slot of kv cells and emplace the ubatch there @@ -424,3 +570,67 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + float s_mask(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/src/llama-memory.h b/src/llama-memory.h index 44a45da9f5891..b3799d66e8c17 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -42,18 +42,35 @@ enum llama_memory_status { LLAMA_MEMORY_STATUS_FAILED_COMPUTE, }; -class llama_memory_decode_state_i { +// the interface for managing the memory state during batch processing +// this interface is implemented per memory type. see: +// - llama_kv_cache_unified_state +// - llama_kv_cache_unified_iswa_state +// ... +// +// the only method that can mutate the memory and the memory state is llama_memory_i::apply() +// +// TODO: rename to llama_memory_context_i ? +class llama_memory_state_i { public: - virtual ~llama_memory_decode_state_i() = default; + virtual ~llama_memory_state_i() = default; - // consume the next ubatch from the decode state - // return nullptr if we are done - virtual llama_ubatch * next() = 0; + // consume the current ubatch from the state and proceed to the next one + // return false if we are done + virtual bool next() = 0; + + // apply the memory state for the current ubatch to the memory object + // return false on failure + virtual bool apply() = 0; // TODO: this might get reworked in the future when refactoring llama_batch virtual std::vector & out_ids() = 0; + // get the current ubatch + virtual const llama_ubatch & get_ubatch() const = 0; + + // get the status of the memory state virtual llama_memory_status get_status() const = 0; }; -using llama_memory_decode_state_ptr = std::unique_ptr; +using llama_memory_state_ptr = std::unique_ptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3f1f6c9bf3b06..e85becbb8f695 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8892,9 +8892,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -8912,8 +8912,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_self->k_l[il]; - ggml_tensor * ssm_states_all = kv_self->v_l[il]; + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_copy_mask_state( @@ -11640,7 +11640,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11650,7 +11650,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11762,7 +11762,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11781,9 +11781,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -12036,7 +12036,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12045,7 +12045,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -12116,7 +12116,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12130,9 +12130,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); From f2ded9d44b18afac15673d94ed741dfca0b55472 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 30 May 2025 11:39:52 +0300 Subject: [PATCH 11/14] kv-cache : add comments ggml-ci --- src/llama-kv-cache.cpp | 13 ++++++++++--- src/llama-kv-cells.h | 5 +++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0906287c26226..67dc999eb636a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -493,9 +493,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { LLAMA_LOG_WARN("\n%s\n", ss.c_str()); } - LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n", n_swa, cells.seq_pos_min(0), cells.seq_pos_max(0)); - LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n", n_swa, cells.seq_pos_min(1), cells.seq_pos_max(1)); - LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n", n_swa, cells.seq_pos_min(2), cells.seq_pos_max(2)); + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } #endif uint32_t n_tested = 0; @@ -538,6 +542,9 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); // SWA mask + // note: we insert only in the cell with minimum pos in order to preserve the invariant that + // all positions between [pos_min, pos_max] for each sequence will be present in the cache + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 if (pos_cell == seq_pos_min[seq_id_cell] && is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { seq_pos_min[seq_id_cell]++; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index d5331957105cd..9e2c4d927699d 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -138,6 +138,7 @@ class llama_kv_cells_unified { } } + // clear a non-empty cell void rm(uint32_t i) { assert(i < pos.size()); assert(pos[i] != -1); @@ -202,6 +203,7 @@ class llama_kv_cells_unified { return false; } + // number of different sequences in the cell int seq_count(uint32_t i) const { assert(i < pos.size()); assert(pos[i] != -1); @@ -209,6 +211,7 @@ class llama_kv_cells_unified { return seq[i].count(); } + // check if the cell contains seq_id bool seq_has(uint32_t i, llama_seq_id seq_id) const { assert(i < pos.size()); assert(seq_id >= 0); @@ -226,6 +229,8 @@ class llama_kv_cells_unified { seq_pos[seq_id].insert(pos[i]); } + // return the sequence id of this cell + // note: call only for cells with exactly one sequence llama_seq_id seq_get(uint32_t i) const { assert(seq[i].count() == 1); From e230e5144739acff637d284a00dfdc092e7d6152 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 31 May 2025 10:04:18 +0300 Subject: [PATCH 12/14] server : update batching logic to reset n_batch on successful decode --- examples/parallel/parallel.cpp | 13 ++++++++++--- tools/server/server.cpp | 12 +++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 931ea0035cffb..22118faf8c20d 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -362,7 +362,9 @@ int main(int argc, char ** argv) { // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t i_next = 0; + + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -370,7 +372,7 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, @@ -396,13 +398,18 @@ int main(int argc, char ** argv) { // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; continue; } LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = params.n_batch; + for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { continue; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 91b73afa7c794..03106760b827b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3381,8 +3381,10 @@ struct server_context { } } + int32_t i_next = 0; + // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { @@ -3430,11 +3432,15 @@ struct server_context { SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - i -= n_batch; - continue; // continue loop of n_batch } + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + for (auto & slot : slots) { if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; // continue loop of slots From 3cf5186356073001ad5dd0785dd4fc80b8cdcc61 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 31 May 2025 10:04:50 +0300 Subject: [PATCH 13/14] server : upon full re-processing, remove the sequence from the cache --- tools/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 03106760b827b..90981ff9a5ef7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3219,6 +3219,7 @@ struct server_context { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; } } From 71619f2d4f1bc75b52099ffe15185a9244b2a79b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 31 May 2025 10:05:13 +0300 Subject: [PATCH 14/14] kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci --- src/llama-kv-cache.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 67dc999eb636a..86c4f2816f809 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1811,6 +1811,8 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + // TODO: if we fail with split_simple, we should attempt split_equal + std::vector ubatches; while (sbatch.n_tokens > 0) {