diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 931ea0035cffb..22118faf8c20d 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -362,7 +362,9 @@ int main(int argc, char ** argv) { // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t i_next = 0; + + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -370,7 +372,7 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, @@ -396,13 +398,18 @@ int main(int argc, char ** argv) { // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; continue; } LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = params.n_batch; + for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { continue; diff --git a/include/llama.h b/include/llama.h index 01762bea2bf96..adc4c69288a3d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -259,9 +259,9 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id; + int8_t * logits; // TODO: rename this to "output" } llama_batch; enum llama_model_kv_override_type { @@ -677,12 +677,14 @@ extern "C" { // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, llama_seq_id seq_id); // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, @@ -692,12 +694,14 @@ extern "C" { // This will be applied: // - lazily on next llama_decode() // - explicitly with llama_kv_self_update() + // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG] LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); // Check if the context supports KV cache shifting LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG] LLAMA_API void llama_kv_self_update(struct llama_context * ctx); // diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index b98e3256c390d..6a19a243118d3 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { break; } } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), }; + return ubatch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 6305051b62b79..b8260b94fd2d0 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -11,15 +11,15 @@ struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? - uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - int32_t * n_seq_id; // [n_seqs] - llama_seq_id ** seq_id; // [n_seqs] + int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id; int8_t * output; // [n_tokens] }; @@ -49,13 +49,18 @@ struct llama_sbatch { const llama_batch * batch = nullptr; - // buffers for the ubatch - std::vector ubatch_token; - std::vector ubatch_embd; - std::vector ubatch_pos; - std::vector ubatch_n_seq_id; - std::vector ubatch_seq_id; - std::vector ubatch_output; + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e153351af3809..808fe5991088c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,9 +6,10 @@ #include "llama-model.h" #include "llama-kv-cache.h" +#include #include +#include #include -#include // // llama_context @@ -259,15 +260,9 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - - // restore later - // TODO: something cleaner - const auto n_outputs_save = n_outputs; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); int n_splits_pp = -1; @@ -279,23 +274,17 @@ llama_context::llama_context( // simulate full KV cache llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->set_full(); + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } cross.v_embd.clear(); // reserve pp graph first so that buffers are only allocated once { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - // max number of outputs - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -305,16 +294,8 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_tg.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(1, 1, 1, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -324,22 +305,12 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - n_outputs = n_outputs_save; - for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; @@ -454,33 +425,25 @@ const llama_kv_cache * llama_context::get_kv_self() const { } void llama_context::kv_self_update() { - bool need_reserve = false; + if (!memory) { + return; + } llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = kv_self->update(*this); - - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - // simulate full KV cache - kv_self->set_full(); - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + if (kv_self->update(*this)) { + // if the KV cache did any computation, we have to reserve a new worst-case graph + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); } } } @@ -676,6 +639,49 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + int llama_context::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -737,8 +743,6 @@ int llama_context::encode(llama_batch & inp_batch) { n_outputs = n_tokens; - //batch_manager->prepare(ubatch); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -749,26 +753,18 @@ int llama_context::encode(llama_batch & inp_batch) { // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 cparams.causal_attn = false; - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(&ubatch); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; - const auto compute_status = graph_compute(gf, n_tokens > 1); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } } auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); @@ -889,8 +885,6 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self); - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT // TODO: move the validation to the llama_batch_allocr @@ -936,7 +930,28 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); + // handle any pending defrags/shifts + kv_self_update(); + + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!kv_state) { + return -2; + } + + switch (kv_state->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + // not a fatal error, we can re-try with a different batch + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return -2; + } + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -944,13 +959,10 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; - // handle any pending defrags/shifts - kv_self_update(); - int64_t n_outputs_prev = 0; - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + do { + const auto & ubatch = kv_state->get_ubatch(); // count the outputs in this u_batch { @@ -969,33 +981,37 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // find KV slot - if (!kv_self->find_slot(ubatch)) { - return 1; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits::max() }; - ggml_backend_sched_alloc_graph(sched.get(), gf); + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; - res->set_inputs(&ubatch); + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } + + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + llama_kv_self_seq_rm(this, s, pos_min[s], -1); + } - const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); - if (compute_status != GGML_STATUS_SUCCESS) { - switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } @@ -1082,10 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } - - // finalize the batch processing - kv_guard.commit(); + } while (kv_state->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1094,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = sbatch.out_ids; + auto & out_ids = kv_state->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1254,11 +1267,52 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + if (n_tokens % n_seqs != 0) { + n_tokens = (n_tokens / n_seqs) * n_seqs; + n_outputs = std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1270,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ memory.get(), + /*.mstate =*/ mstate, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -1951,7 +2005,6 @@ void llama_context::opt_epoch_iter( llama_kv_cache * kv_self = static_cast(memory.get()); kv_self->clear(); - llama_kv_cache_guard kv_guard(kv_self); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -1974,7 +2027,11 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1982,20 +2039,19 @@ void llama_context::opt_epoch_iter( GGML_ABORT("TODO: handle this error"); }; - for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + uint32_t pos_batch = 0; + do { + const auto & ubatch = kv_state->get_ubatch(); n_outputs = ubatch.n_tokens; - // TODO: not sure if this is needed - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - GGML_ABORT("TODO: handle this error"); + if (!kv_state->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); struct ggml_context * ctx_compute_opt; { @@ -2010,6 +2066,7 @@ void llama_context::opt_epoch_iter( } ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); @@ -2027,10 +2084,10 @@ void llama_context::opt_epoch_iter( callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); } ggml_free(ctx_compute_opt); - } - } - kv_guard.commit(); + pos_batch += ubatch.n_tokens; + } while (kv_state->next()); + } } void llama_context::opt_epoch( diff --git a/src/llama-context.h b/src/llama-context.h index c0ceacb10ce6f..5b79bafa75db7 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -18,6 +18,9 @@ struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; +class llama_memory_i; +class llama_memory_state_i; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -47,6 +50,7 @@ struct llama_context { llama_kv_cache * get_kv_self(); const llama_kv_cache * get_kv_self() const; + // TODO: remove void kv_self_update(); enum llama_pooling_type pooling_type() const; @@ -88,6 +92,16 @@ struct llama_context { int32_t il_start, int32_t il_end); + // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); + int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); @@ -180,16 +194,18 @@ struct llama_context { ggml_cgraph * graph_init(); // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched); + + // reserve a graph with a dummy ubatch of the specified size + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype); + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); llm_graph_cb graph_get_cb() const; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7c383e2eb3f27..b30f6fb4f4145 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -83,7 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_self->set_input_pos_bucket(pos_bucket, ubatch); + kv_state->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -234,7 +234,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -242,7 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_copy(i); + data[i] = kv_state->s_copy(i); } } } @@ -250,7 +250,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); @@ -258,7 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_mask(i); + data[i] = kv_state->s_mask(i); } } } @@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -448,7 +448,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - memory (params.memory), + mstate (params.mstate), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { @@ -954,11 +954,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; @@ -971,11 +971,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_mask; @@ -1025,11 +1025,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, kv_self); + auto inp = std::make_unique(hparams, kv_state); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->pos_bucket; @@ -1231,14 +1231,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1268,19 +1268,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); // store to KV cache { - ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_self->get_k(ctx0, il); - ggml_tensor * v = kv_self->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1301,12 +1301,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { - const auto n_kv = kv_self->get_kv_base()->get_n(); + const auto n_kv = kv_state->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_self->get_kv_swa()->get_n(); + const auto n_kv = kv_state->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1348,23 +1348,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const bool is_swa = hparams.is_swa(il); + const auto * kv_state_iswa = static_cast(mstate); - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const bool is_swa = hparams.is_swa(il); - const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base(); + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv->get_k(ctx0, il); - ggml_tensor * v = kv->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1446,12 +1446,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_self->n; - const auto kv_head = kv_self->head; + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size()); // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv @@ -1478,13 +1478,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_self->k_l[il]; + ggml_tensor * token_shift_all = kv_state->get_k_l(il); ggml_tensor * token_shift = build_copy_mask_state( gf, token_shift_all, state_copy, state_mask, @@ -1499,19 +1499,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) ); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 2b85bb25befba..d1c5dd1bf036f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,10 +17,11 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_i; -class llama_kv_cache_unified; -class llama_kv_cache_unified_iswa; -class llama_kv_cache_recurrent; +class llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -141,7 +142,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -188,26 +189,26 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -247,10 +248,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified * kv_self) : + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified() = default; @@ -264,7 +265,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -272,10 +273,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa * kv_self) : + const llama_kv_cache_unified_iswa_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -292,7 +293,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa * kv_self; + const llama_kv_cache_unified_iswa_state * kv_state; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -383,10 +384,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; int32_t n_outputs; @@ -435,10 +436,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 766f8d079afb2..86c4f2816f809 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -17,11 +17,6 @@ // llama_kv_cache_unified // -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} - llama_kv_cache_unified::llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -293,26 +288,81 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -void llama_kv_cache_unified::restore() { - for (auto & state : recovery.states) { - cells.set(state.i, state.cells); +llama_memory_state_ptr llama_kv_cache_unified::init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - recovery.clear(); + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads), std::move(ubatches)); } -void llama_kv_cache_unified::commit() { - if (recovery.states.empty()) { - LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); - return; +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { + std::vector res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. don't modify the cells yet + const int32_t head_new = find_slot(ubatch); + if (head_new < 0) { + success = false; + break; + } + + // remeber the position that we found + res.push_back(head_new); + + // store the old state of the cells in the recovery stack + states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + + // now emplace the ubatch + apply_ubatch(head_new, ubatch); + } + + // iterate backwards and restore the cells to their original state + for (auto it = states.rbegin(); it != states.rend(); ++it) { + cells.set(it->head_new, it->cells); + head = it->head_old; + } + + if (!success) { + return {}; } - recovery.clear(); + return res; } bool llama_kv_cache_unified::update(llama_context & lctx) { - bool need_reserve = false; + bool updated = false; auto * sched = lctx.get_sched(); @@ -330,14 +380,24 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); + return updated; + } - ggml_backend_sched_alloc_graph(sched, gf); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); + return updated; + } res->set_inputs(nullptr); - lctx.graph_compute(gf, false); + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); + return updated; + } - need_reserve = true; + updated = true; } cells.reset_shift(); @@ -352,26 +412,38 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } - ggml_backend_sched_alloc_graph(sched, gf); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } res->set_inputs(nullptr); - lctx.graph_compute(gf, false); + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } - need_reserve = true; + updated = true; } do_defrag = false; } - return need_reserve; + return updated; } void llama_kv_cache_unified::defrag_sched(float thold) { + const auto n_kv = cells.used_max_p1(); + // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -381,55 +453,37 @@ void llama_kv_cache_unified::defrag_sched(float thold) { } } -void llama_kv_cache_unified::set_full() { - n = cells.size(); - - // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not - // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. - // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so - // setting it to 0 is the simplest way to achieve that - // ref: https://github.com/ggml-org/llama.cpp/issues/13359 - head = 0; -} - -llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; + uint32_t head_cur = this->head; + // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (head > cells.get_used() + 2*ubatch.n_tokens) { - head = 0; + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; } // otherwise, one cell per token. if (n_tokens > cells.size()) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return false; + return -1; } //#define FIND_SLOT_DEBUG 1 #if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); // for debugging { std::string ss; if (n_swa > 0) { - for (uint32_t i = 0; i < size; ++i) { + for (uint32_t i = 0; i < cells.size(); ++i) { if (cells.is_empty(i)) { ss += '.'; } else { - ss += 'x'; + ss += std::to_string(cells.seq_get(i)); } if (i%256 == 255) { ss += '\n'; @@ -438,23 +492,70 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { } LLAMA_LOG_WARN("\n%s\n", ss.c_str()); } + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } #endif uint32_t n_tested = 0; while (true) { - if (head + n_tokens > cells.size()) { - n_tested += cells.size() - head; - head = 0; + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; continue; } + // keep track of what the minimum sequence positions would be if we accept the ubatch + llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos_min[s] = cells.seq_pos_min(s); + } + bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { - // TODO: improve to accept cells that are masked by the SWA - if (!cells.is_empty(head + i)) { + const llama_pos pos = ubatch.pos[i]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // causal mask + if (cells.seq_has(head_cur + i, seq_id)) { + can_use = pos_cell >= pos; + } + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + // note: we insert only in the cell with minimum pos in order to preserve the invariant that + // all positions between [pos_min, pos_max] for each sequence will be present in the cache + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + if (pos_cell == seq_pos_min[seq_id_cell] && + is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + seq_pos_min[seq_id_cell]++; + can_use = true; + } + } + } + + if (!can_use) { found = false; - head += i + 1; + head_cur += i + 1; n_tested += i + 1; break; } @@ -466,58 +567,55 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return -1; } } - // store the old state of the cells in the recovery stack - recovery.states.push_back({head, cells.cp(head, n_tokens)}); + return head_cur; +} + +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (!cells.is_empty(head_cur + i)) { + cells.rm(head_cur + i); + } - for (uint32_t i = 0; i < n_tokens; ++i) { - cells.pos_set(head + i, ubatch.pos[i]); + cells.pos_set(head_cur + i, ubatch.pos[i]); for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells.seq_add(head + i, ubatch.seq_id[i][j]); + cells.seq_add(head_cur + i, ubatch.seq_id[i][j]); } } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); - -#ifdef FIND_SLOT_DEBUG - LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); -#endif - - return true; + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; } bool llama_kv_cache_unified::get_can_shift() const { return true; } -uint32_t llama_kv_cache_unified::get_n() const { - return n; -} - uint32_t llama_kv_cache_unified::get_size() const { return cells.size(); } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ggml_row_size(k->type, hparams.n_embd_head_k), ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), 0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -525,7 +623,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] 0); @@ -533,13 +631,13 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons // note: v->nb[1] > v->nb[2] return ggml_view_3d(ctx, v, - n, hparams.n_head_kv(il), hparams.n_embd_head_v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] ggml_row_size(v->type, v->ne[1]), // v->nb[2] 0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; @@ -548,12 +646,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head); + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); return ggml_cpy(ctx, k_cur, k_view); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -567,12 +665,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ if (!v_trans) { v_view = ggml_view_1d(ctx, v, n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head); + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); } else { // note: the V cache is transposed when not using flash attention v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), (v->ne[1])*ggml_element_size(v), - ( head)*ggml_element_size(v)); + (head_cur)*ggml_element_size(v)); v_cur = ggml_transpose(ctx, v_cur); } @@ -580,33 +678,6 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ return ggml_cpy(ctx, v_cur, v_view); } -void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) { - // no pruning is needed when the cache does not use SWA - GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache"); - - int n_attended = 0; - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.seq_has(i, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(i); - - if (p0 <= pmin && !is_masked_swa(p0, pmin)) { - n_attended++; - } - - if (is_masked_swa(p0, pmax)) { - cells.seq_rm(i, seq_id); - } - } - - if (n_attended < std::min(n_swa, pmin)) { - LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa); - } -} - void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; @@ -615,7 +686,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = n; + const auto n_kv = dst->ne[0]; // Use only the previous KV cells of the correct sequence for each token of the ubatch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -636,7 +707,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub for (int j = 0; j < n_seq_tokens; ++j) { const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; - for (int i = 0; i < n_kv; ++i) { + for (uint32_t i = 0; i < n_kv; ++i) { float f = 0.0f; bool masked = false; @@ -672,7 +743,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // mask padded tokens if (data) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { + for (uint32_t j = 0; j < n_kv; ++j) { data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } @@ -698,7 +769,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama int32_t * data = (int32_t *) dst->data; - const int64_t n_kv = n; + const int32_t n_kv = dst->ne[0]; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { @@ -1362,20 +1433,24 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell batch.seq_id[i] = &dest_seq_id; } - if (!find_slot(batch)) { + const auto head_cur = find_slot(batch); + if (head_cur < 0) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - commit(); + apply_ubatch(head_cur, batch); - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head) == batch.pos[0]); - GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); } else { // whole KV cache restore @@ -1425,10 +1500,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); return false; } + if (cell_count > cells.size()) { LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); return false; } + if (this->v_trans != (bool) v_trans) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; @@ -1539,6 +1616,108 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell return true; } +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv) : status(status), kv(kv) { + n_kv = kv->get_size(); + head = 0; + } + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads(std::move(heads)), + ubatches(std::move(ubatches)) { + } + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + // // llama_kv_cache_unified_iswa // @@ -1561,13 +1740,12 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); - // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); size_swa = size_base; - do_prune = false; } LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); @@ -1628,31 +1806,46 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -void llama_kv_cache_unified_iswa::restore() { - kv_base->restore(); - kv_swa ->restore(); -} +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); -void llama_kv_cache_unified_iswa::commit() { - kv_base->commit(); - kv_swa ->commit(); + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); - // slide the attention window, forgetting/pruning old tokens that are outside the window - if (do_prune) { - for (const auto & [seq_id, entry] : pending.pos) { - kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax); - } + // TODO: if we fail with split_simple, we should attempt split_equal + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - pending.clear(); + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); } bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = true; + bool res = false; - res = res & kv_base->update(lctx); - res = res & kv_swa ->update(lctx); + res = res | kv_base->update(lctx); + res = res | kv_swa ->update(lctx); return res; } @@ -1662,68 +1855,107 @@ void llama_kv_cache_unified_iswa::defrag_sched(float thold) { kv_swa ->defrag_sched(thold); } -void llama_kv_cache_unified_iswa::set_full() { - kv_base->set_full(); - kv_swa ->set_full(); +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); } -llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) { - pending.clear(); +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} - if (do_prune) { - for (int i = 0; i < batch.n_tokens; ++i) { - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - const llama_seq_id seq_id = batch.seq_id[i][s]; - const llama_pos pos = batch.pos[i]; +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} - if (pending.pos.find(seq_id) == pending.pos.end()) { - pending.pos[seq_id].pmin = pos; - pending.pos[seq_id].pmax = pos; - } else { - pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos); - pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos); - } - } - } - } +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} - return llama_sbatch(batch, hparams.n_embd, true, logits_all); +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); } -llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv) : status(status) { + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + } + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; } -bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + bool res = true; - res = res & kv_base->find_slot(batch); - res = res & kv_swa ->find_slot(batch); + res = res & state_base->apply(); + res = res & state_swa ->apply(); return res; } -bool llama_kv_cache_unified_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; } -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; } -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const { - return kv_base.get(); +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_base.get(); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { - return kv_swa.get(); +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_swa.get(); } // @@ -2071,50 +2303,82 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_recurrent::restore() { - if (pending.ranges.empty()) { - return; +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); } - seq_rm(-1, -1, -1); -} + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } -void llama_kv_cache_recurrent::commit() { - pending.ranges.clear(); + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); } -bool llama_kv_cache_recurrent::update(llama_context & ctx) { - GGML_UNUSED(ctx); - return false; +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); } -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; -void llama_kv_cache_recurrent::set_full() { - n = size; - head = 0; -} + bool success = true; + + // TODO: here we have to verify that all ubatches can fit in the cells + // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells + // during the compute of each ubatch. to reproduce, uncomment the following loop and run: + // + // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 + // + // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed + // + GGML_UNUSED(ubatches); + //for (const auto & ubatch : ubatches) { + // if (!find_slot(ubatch)) { + // success = false; + // break; + // } + //} -llama_sbatch llama_kv_cache_recurrent::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, false, logits_all); + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; } -llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); - } +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + // noop + return false; +} - return sbatch.split_equal(n_ubatch); +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop } -bool llama_kv_cache_recurrent::find_slot( - const llama_ubatch & ubatch) { +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -2332,18 +2596,6 @@ float llama_kv_cache_recurrent::s_mask(int i) const { return res; } -uint32_t llama_kv_cache_recurrent::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - size_t llama_kv_cache_recurrent::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -2558,11 +2810,11 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce } batch.n_seq_id[0] = 1; batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - commit(); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells @@ -2745,3 +2997,84 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 0 : kv->head; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->s_copy(i); +} + +float llama_kv_cache_recurrent_state::s_mask(int i) const { + return kv->s_mask(i); +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index ce6261e45a6e1..d2439e13603a0 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,6 +2,7 @@ #include "llama.h" #include "llama-io.h" +#include "llama-batch.h" #include "llama-graph.h" #include "llama-memory.h" #include "llama-kv-cells.h" @@ -14,47 +15,34 @@ struct llama_cparams; struct llama_hparams; -struct llama_ubatch; -struct llama_sbatch; struct llama_model; struct llama_context; struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; - // call if batch processing fails - restores the cache state - virtual void restore() = 0; + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; - // call after successful batch processing - clears any pending state - virtual void commit() = 0; + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch + // return true if any operations were performed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing - virtual void defrag_sched(float thold) = 0; - - // simulate full cache, used for allocating worst-case compute buffers - // TODO: remove - virtual void set_full() = 0; - - // - // batch processing + // TODO: change to + // llama_memory_state_ptr init_defrag(float thold) = 0; // - - // ============================================================================================================= - // TODO: refactor and simplify this [TAG: KV_API] - - virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; - - // different KV caches require different batch splitting strategies - virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; - - // find an empty slot of size "n_tokens" in the cache - virtual bool find_slot(const llama_ubatch & batch) = 0; - - // ============================================================================================================= + virtual void defrag_sched(float thold) = 0; // getters virtual bool get_can_shift() const = 0; @@ -69,25 +57,6 @@ struct llama_kv_cache : public llama_memory_i { virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; -// -// llama_kv_cache_guard -// - -struct llama_kv_cache_guard { - llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} - - ~llama_kv_cache_guard() { - kv->restore(); - } - - void commit() { - kv->commit(); - } - -private: - llama_kv_cache * kv; -}; - // // llama_kv_cache_unified // @@ -133,22 +102,17 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - void defrag_sched(float thold) override; - - void set_full() override; + llama_memory_state_ptr init_full() override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + bool update(llama_context & lctx) override; - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch) override; + void defrag_sched(float thold) override; bool get_can_shift() const override; @@ -161,18 +125,40 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache_unified specific API // - uint32_t get_n() const; uint32_t get_size() const; + // + // graph_build API + // + + uint32_t get_n_kv() const; + // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - // store k_cur and v_cur in the cache based on the current head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the head locations + // return empty vector on failure + std::vector prepare(const std::vector & ubatches); - void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax); + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_k_shift (ggml_tensor * dst) const; @@ -194,11 +180,9 @@ class llama_kv_cache_unified : public llama_kv_cache { bool do_defrag = false; bool v_trans = true; // the value tensor is transposed - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - - // computed before each graph build - // TODO: cells should start to maintain this value dynamically based on the edits - uint32_t n = 0; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + uint32_t head = 0; const uint32_t n_seq_max = 1; @@ -220,24 +204,6 @@ class llama_kv_cache_unified : public llama_kv_cache { // model layer id -> KV cache layer id std::unordered_map map_layer_ids; - // recovery information used to restore the KV cells to their original state in case of a failure - // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation - // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] - struct { - void clear() { - states.clear(); - } - - struct state { - uint32_t i; - - llama_kv_cells_unified cells; - }; - - // stack with the partial states before each ubatch - std::vector states; - } recovery; - // defrag struct { std::vector ids; @@ -279,13 +245,88 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv); + + // used to create a state from a batch + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads; + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; + // // llama_kv_cache_unified_iswa // // utilizes two instances of llama_kv_cache_unified // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -// upon successful commit, the SWA cache removes old tokens outside the n_swa window class llama_kv_cache_unified_iswa : public llama_kv_cache { public: @@ -322,19 +363,17 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - void set_full() override; + llama_memory_state_ptr init_full() override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + bool update(llama_context & lctx) override; - bool find_slot(const llama_ubatch & batch) override; + void defrag_sched(float thold) override; bool get_can_shift() const override; @@ -347,58 +386,80 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { // llama_kv_cache_unified_iswa specific API // - llama_kv_cache_unified * get_kv_base() const; - llama_kv_cache_unified * get_kv_swa () const; + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; private: const llama_hparams & hparams; - bool do_prune = true; + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; - struct { - struct entry { - llama_pos pmin; - llama_pos pmax; - }; +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); - void clear() { - pos.clear(); - } + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv); - // used to perform SWA pruning of old tokens - std::unordered_map pos; - } pending; + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); - std::unique_ptr kv_base; - std::unique_ptr kv_swa; + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + const llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + std::unique_ptr state_base; + std::unique_ptr state_swa; }; // // llama_kv_cache_recurrent // +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it class llama_kv_cache_recurrent : public llama_kv_cache { public: - struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - llama_kv_cache_recurrent( const llama_model & model, ggml_type type_k, @@ -428,19 +489,22 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // llama_kv_cache // - void restore() override; - void commit() override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; - bool update(llama_context & ctx) override; + llama_memory_state_ptr init_full() override; - void defrag_sched(float thold) override; + bool update(llama_context & lctx) override; - void set_full() override; + void defrag_sched(float thold) override; - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + bool prepare(const std::vector & ubatches); - bool find_slot(const llama_ubatch & batch) override; + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); bool get_can_shift() const override; @@ -460,6 +524,27 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // computed before each graph build uint32_t n = 0; + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + std::vector cells; std::vector k_l; // per layer @@ -469,26 +554,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache { //const llama_model & model; const llama_hparams & hparams; - // commit/restore cache - // TODO: rework for recurrent cache - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; - - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; - const uint32_t n_seq_max = 1; std::vector ctxs; std::vector bufs; - // find how many cells are currently in use - uint32_t cell_max() const; - size_t total_size() const; size_t size_k_bytes() const; @@ -500,3 +570,67 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + float s_mask(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index dbbd03fcba281..9e2c4d927699d 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -68,12 +68,6 @@ class llama_kv_cells_unified { // the index of the last cell that is used + 1 // return 0 if no cells are used uint32_t used_max_p1() const { -#if 0 - if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); - if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); - if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); -#endif - return used.empty() ? 0 : *used.rbegin() + 1; } @@ -144,6 +138,19 @@ class llama_kv_cells_unified { } } + // clear a non-empty cell + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] = -1; + seq[i].reset(); + + used.erase(i); + } + // note: call only if the cell has seq_id // return true if the cell becomes empty bool seq_rm(uint32_t i, llama_seq_id seq_id) { @@ -196,6 +203,15 @@ class llama_kv_cells_unified { return false; } + // number of different sequences in the cell + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + + // check if the cell contains seq_id bool seq_has(uint32_t i, llama_seq_id seq_id) const { assert(i < pos.size()); assert(seq_id >= 0); @@ -213,6 +229,20 @@ class llama_kv_cells_unified { seq_pos[seq_id].insert(pos[i]); } + // return the sequence id of this cell + // note: call only for cells with exactly one sequence + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + // the minimum position of sequence seq_id currently present in any of the cells // return -1 if the sequence is not present llama_pos seq_pos_min(llama_seq_id seq_id) const { @@ -268,6 +298,7 @@ class llama_kv_cells_unified { void pos_set(uint32_t i, llama_pos p) { assert(i < pos.size()); assert(pos[i] == -1); + assert(seq[i].none()); pos[i] = p; diff --git a/src/llama-memory.h b/src/llama-memory.h index a2d250434affa..b3799d66e8c17 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -2,6 +2,11 @@ #include "llama.h" +#include +#include + +struct llama_ubatch; + struct llama_memory_params { // kv cache ggml_type type_k; @@ -30,3 +35,42 @@ class llama_memory_i { virtual bool get_can_edit() const = 0; }; + +enum llama_memory_status { + LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_FAILED_PREPARE, + LLAMA_MEMORY_STATUS_FAILED_COMPUTE, +}; + +// the interface for managing the memory state during batch processing +// this interface is implemented per memory type. see: +// - llama_kv_cache_unified_state +// - llama_kv_cache_unified_iswa_state +// ... +// +// the only method that can mutate the memory and the memory state is llama_memory_i::apply() +// +// TODO: rename to llama_memory_context_i ? +class llama_memory_state_i { +public: + virtual ~llama_memory_state_i() = default; + + // consume the current ubatch from the state and proceed to the next one + // return false if we are done + virtual bool next() = 0; + + // apply the memory state for the current ubatch to the memory object + // return false on failure + virtual bool apply() = 0; + + // TODO: this might get reworked in the future when refactoring llama_batch + virtual std::vector & out_ids() = 0; + + // get the current ubatch + virtual const llama_ubatch & get_ubatch() const = 0; + + // get the status of the memory state + virtual llama_memory_status get_status() const = 0; +}; + +using llama_memory_state_ptr = std::unique_ptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3f1f6c9bf3b06..e85becbb8f695 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8892,9 +8892,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -8912,8 +8912,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_self->k_l[il]; - ggml_tensor * ssm_states_all = kv_self->v_l[il]; + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_copy_mask_state( @@ -11640,7 +11640,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11650,7 +11650,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11762,7 +11762,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11781,9 +11781,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -12036,7 +12036,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12045,7 +12045,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -12116,7 +12116,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12130,9 +12130,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 5d03dc3dc790a..90981ff9a5ef7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3214,9 +3214,12 @@ struct server_context { } if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { - if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) { + const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id); + if (pos_min > 0) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; } } @@ -3379,8 +3382,10 @@ struct server_context { } } + int32_t i_next = 0; + // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { @@ -3425,13 +3430,18 @@ struct server_context { // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); continue; // continue loop of n_batch } + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + for (auto & slot : slots) { if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; // continue loop of slots