From 5e354e3ca21fa74dcf935661bce5bf8b2327576e Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 17 Sep 2024 21:43:01 +0200 Subject: [PATCH 1/7] llama: propagating the results of `graph_compute` to the user interface --- src/llama.cpp | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 4d89c522257c5..c9e11e688fe64 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17181,7 +17181,7 @@ static void llama_output_reorder(struct llama_context * ctx) { } } -static void llama_graph_compute( +static enum ggml_status llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, int n_threads, @@ -17196,12 +17196,14 @@ static void llama_graph_compute( set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); - if (err != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err); + auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); + + return status; } // decode a batch of tokens by evaluating the transformer @@ -17387,7 +17389,18 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } // update the kv ring buffer { @@ -17624,7 +17637,18 @@ static int llama_encode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } // extract embeddings if (embd) { From 47018932336f7f7f3b7ae9a79e6d35cc7d354d1c Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 24 Sep 2024 21:12:47 +0200 Subject: [PATCH 2/7] llama: reverting kv_cache in case of failed compute --- src/llama.cpp | 59 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c9e11e688fe64..427e1ef060d32 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2811,6 +2811,42 @@ struct llama_kv_cache { } }; +class llama_kv_cache_state { + struct llama_kv_cache_state_short { + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; + uint32_t n = 0; + + std::vector cells; + } old_state; + + bool saved = false; + +public: + void save_state(const llama_kv_cache& cache) { + old_state.head = cache.head; + old_state.size = cache.size; + old_state.used = cache.used; + old_state.n = cache.n; + old_state.cells = cache.cells; + + saved = true; + } + + void restore(llama_kv_cache& cache) { + if (saved) { + cache.head = old_state.head; + cache.size = old_state.size; + cache.used = old_state.used; + cache.n = old_state.n; + cache.cells = std::move(old_state.cells); + + saved = false; + } + } +}; + 
struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -17256,6 +17292,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; + llama_kv_cache_state kv_cache_state_holder; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17333,6 +17370,7 @@ static int llama_decode_internal( // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); + kv_cache_state_holder.save_state(kv_self); // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it @@ -17390,16 +17428,17 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (compute_status != GGML_STATUS_SUCCESS) { + kv_cache_state_holder.restore(kv_self); + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } } // update the kv ring buffer From acb9528362d17c2bc76373f4fba5ca185e214e11 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Mon, 21 Oct 2024 09:05:33 +0200 Subject: [PATCH 3/7] llama: `llama_kv_cache_state` was removed, only the result of `llama_graph_compute` is returned --- src/llama.cpp | 59 +++++++++------------------------------------------ 1 file changed, 10 insertions(+), 49 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 427e1ef060d32..c9e11e688fe64 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2811,42 +2811,6 @@ struct llama_kv_cache { } }; -class llama_kv_cache_state { - struct llama_kv_cache_state_short { - uint32_t head = 0; - uint32_t size = 0; - uint32_t used = 0; - uint32_t n = 0; - - std::vector cells; - } old_state; - - bool saved = false; - -public: - void save_state(const llama_kv_cache& cache) { - old_state.head = cache.head; - old_state.size = cache.size; - old_state.used = cache.used; - old_state.n = cache.n; - old_state.cells = cache.cells; - - saved = true; - } - - void restore(llama_kv_cache& cache) { - if (saved) { - cache.head = old_state.head; - cache.size = old_state.size; - cache.used = old_state.used; - cache.n = old_state.n; - cache.cells = std::move(old_state.cells); - - saved = false; - } - } -}; - struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -17292,7 +17256,6 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; - llama_kv_cache_state kv_cache_state_holder; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17370,7 +17333,6 @@ static int llama_decode_internal( // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - kv_cache_state_holder.save_state(kv_self); // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it @@ -17428,17 +17390,16 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); - if (compute_status != GGML_STATUS_SUCCESS) { - kv_cache_state_holder.restore(kv_self); - switch 
(compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; - } + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; } // update the kv ring buffer From 0026c810d7c18743baa71198409734178df7769a Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Mon, 21 Oct 2024 10:47:27 +0200 Subject: [PATCH 4/7] llama: restore a kv_cache in case of failed computation --- src/llama.cpp | 77 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c9e11e688fe64..85e613a63fef5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2811,6 +2811,22 @@ struct llama_kv_cache { } }; +// saves the kv_cache state for future recovery +// used to preserve the kv_cache state before searching for a slot +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; + uint32_t n = 0; + } old_state; + + std::vector recurrent_cells; // for recurrent models only + std::pair slot_boundaries; // for non-recurrent models only + + bool restore = false; +}; + struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -3508,11 +3524,19 @@ static bool llama_kv_cache_init( // to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { + const struct llama_ubatch & batch, + struct llama_kv_slot_restorer * slot_restorer = nullptr) { const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; + if (slot_restorer != nullptr) { + slot_restorer->old_state.head = cache.head; + slot_restorer->old_state.size = cache.size; + slot_restorer->old_state.used = cache.used; + slot_restorer->old_state.n = cache.n; + } + if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. @@ -3521,6 +3545,11 @@ static bool llama_kv_cache_find_slot( // can only process batches with an equal number of new tokens in each sequence GGML_ASSERT(batch.equal_seqs); + if (slot_restorer != nullptr) { + slot_restorer->recurrent_cells = cache.cells; + slot_restorer->restore = true; + } + int32_t min = cache.size - 1; int32_t max = 0; @@ -3709,6 +3738,11 @@ static bool llama_kv_cache_find_slot( } } + if (slot_restorer != nullptr) { + slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens); + slot_restorer->restore = true; + } + for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; @@ -3998,6 +4032,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) return cparams.flash_attn ? 
256u : 32u; } +static void llama_kv_cache_slot_restore( + const struct llama_kv_slot_restorer & restorer, + struct llama_kv_cache & cache) { + if (restorer.restore) { + cache.head = restorer.old_state.head; + cache.size = restorer.old_state.size; + cache.used = restorer.old_state.used; + cache.n = restorer.old_state.n; + + if (cache.recurrent) { + cache.cells = restorer.recurrent_cells; + } else { + llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1); + } + } +} + // // model loading and saving // @@ -17256,6 +17307,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; + llama_kv_slot_restorer kv_slot_restorer; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17340,7 +17392,7 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch)) { + if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) { return 1; } @@ -17390,16 +17442,17 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (compute_status != GGML_STATUS_SUCCESS) { + llama_kv_cache_slot_restore(kv_slot_restorer, kv_self); + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } } // update the kv ring buffer From ee599f901a3f4ebbc6e42f2273e3e08dfc5b2646 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 22 Oct 2024 19:57:15 +0200 Subject: [PATCH 5/7] llama: correct reverting of the entire batch. also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models --- src/llama.cpp | 122 ++++++++++++++++++++++++++------------------------ 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 85e613a63fef5..48f1f254b35c5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2811,22 +2811,6 @@ struct llama_kv_cache { } }; -// saves the kv_cache state for future recovery -// used to preserve the kv_cache state before searching for a slot -struct llama_kv_slot_restorer { - struct llama_kv_cache_state { - uint32_t head = 0; - uint32_t size = 0; - uint32_t used = 0; - uint32_t n = 0; - } old_state; - - std::vector recurrent_cells; // for recurrent models only - std::pair slot_boundaries; // for non-recurrent models only - - bool restore = false; -}; - struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -3522,21 +3506,24 @@ static bool llama_kv_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. 
-static bool llama_kv_cache_find_slot( +struct llama_kv_cache_slot_info { + std::pair boundaries; + bool found = false; + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + +static struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch, - struct llama_kv_slot_restorer * slot_restorer = nullptr) { + const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; - if (slot_restorer != nullptr) { - slot_restorer->old_state.head = cache.head; - slot_restorer->old_state.size = cache.size; - slot_restorer->old_state.used = cache.used; - slot_restorer->old_state.n = cache.n; - } - if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. @@ -3545,11 +3532,6 @@ static bool llama_kv_cache_find_slot( // can only process batches with an equal number of new tokens in each sequence GGML_ASSERT(batch.equal_seqs); - if (slot_restorer != nullptr) { - slot_restorer->recurrent_cells = cache.cells; - slot_restorer->restore = true; - } - int32_t min = cache.size - 1; int32_t max = 0; @@ -3563,7 +3545,7 @@ static bool llama_kv_cache_find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } if (j > 0) { llama_kv_cell & seq = cache.cells[seq_id]; @@ -3698,15 +3680,17 @@ static bool llama_kv_cache_find_slot( // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return cache.n >= n_seqs; + return llama_kv_cache_slot_info(cache.n >= n_seqs); } // otherwise, one cell per token. if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } uint32_t n_tested = 0; @@ -3734,15 +3718,10 @@ static bool llama_kv_cache_find_slot( if (n_tested >= cache.size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return llama_kv_cache_slot_info_failed; } } - if (slot_restorer != nullptr) { - slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens); - slot_restorer->restore = true; - } - for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; @@ -3756,7 +3735,7 @@ static bool llama_kv_cache_find_slot( cache.used += n_tokens; - return true; + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); } // find how many cells are currently in use @@ -4032,22 +4011,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) return cparams.flash_attn ? 
256u : 32u; } -static void llama_kv_cache_slot_restore( - const struct llama_kv_slot_restorer & restorer, - struct llama_kv_cache & cache) { - if (restorer.restore) { - cache.head = restorer.old_state.head; - cache.size = restorer.old_state.size; - cache.used = restorer.old_state.used; - cache.n = restorer.old_state.n; - - if (cache.recurrent) { - cache.cells = restorer.recurrent_cells; - } else { - llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1); +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + std::vector> slot_boundaries; // for non-recurrent models only + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + void save(const struct llama_kv_cache_slot_info& slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } } } -} + + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; // // model loading and saving @@ -17307,7 +17311,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; - llama_kv_slot_restorer kv_slot_restorer; + llama_kv_slot_restorer kv_slot_restorer(kv_self); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17392,9 +17396,11 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) { + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { return 1; } + kv_slot_restorer.save(slot); if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized @@ -17443,7 +17449,7 @@ static int llama_decode_internal( const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); if (compute_status != GGML_STATUS_SUCCESS) { - llama_kv_cache_slot_restore(kv_slot_restorer, kv_self); + kv_slot_restorer.restore(kv_self); switch (compute_status) { case GGML_STATUS_ABORTED: return 2; From 0638c44821c51f65c00cd41154e928f4e6d67a41 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Sun, 10 Nov 2024 23:08:42 +0100 Subject: [PATCH 6/7] llama: updated comments --- src/llama.cpp | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 48f1f254b35c5..97eee26a577bc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3502,13 +3502,10 @@ static bool llama_kv_cache_init( return true; } -// find an empty slot of size "n_tokens" in the cache -// updates the cache head -// Note: On success, it's important that cache.head points -// to the first cell of the slot. 
+// a structure holds information about the slot found in llama_kv_cache_find_slot struct llama_kv_cache_slot_info { - std::pair boundaries; - bool found = false; + std::pair boundaries; // slot boundaries [begin, end) + bool found = false; // the slot was found explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} @@ -3517,6 +3514,11 @@ struct llama_kv_cache_slot_info { }; static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +// returns a structure holding information about the slot found +// Note: On success, it's important that cache.head points +// to the first cell of the slot. static struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { @@ -4019,7 +4021,9 @@ struct llama_kv_slot_restorer { uint32_t n = 0; } old_state; - std::vector> slot_boundaries; // for non-recurrent models only + // for non-recurrent models only + // list of slots to restore + std::vector> slot_boundaries; bool do_restore = false; @@ -4028,7 +4032,8 @@ struct llama_kv_slot_restorer { old_state.n = cache.n; } - void save(const struct llama_kv_cache_slot_info& slot) { + // saves a slot information for future restoration + void save(const struct llama_kv_cache_slot_info & slot) { if (slot) { do_restore = true; if (slot.boundaries.first != slot.boundaries.second) { @@ -4037,6 +4042,8 @@ struct llama_kv_slot_restorer { } } + // must be explicitly called to restore the kv_cache state + // and rollback changes from all llama_kv_cache_find_slot calls void restore(struct llama_kv_cache & cache) { if (do_restore) { cache.head = old_state.head; @@ -17236,6 +17243,7 @@ static void llama_output_reorder(struct llama_context * ctx) { } } +// returns the result of ggml_backend_sched_graph_compute_async execution static enum ggml_status llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, @@ -17262,6 +17270,9 @@ static enum ggml_status llama_graph_compute( } // decode a batch of tokens by evaluating the transformer +// in case of unsuccessful decoding (error or warning), +// the kv_cache state will be returned to its original state +// (for non-recurrent models) or cleaned (for recurrent models) // // - lctx: llama context // - batch: batch to evaluate From 9ef5d089271550a99c5c578898b2ae612713da7e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Nov 2024 19:59:20 +0200 Subject: [PATCH 7/7] llama : add comments about KV cache state after error --- include/llama.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/llama.h b/include/llama.h index ccb48f73cef5c..5e742642eec8d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -797,7 +797,7 @@ extern "C" { // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); @@ -805,7 +805,7 @@ extern "C" { // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error + // < 0 - error. 
the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch);
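
Caller-side handling, for reference: the sketch below shows how an application might react to the return codes that llama_decode reports once this series is applied. It is a minimal sketch, not part of the patches; the helper name try_decode and the log messages are illustrative assumptions, while the llama_decode call and the documented codes (0, 1, 2, < 0) come from the patches above.

#include "llama.h"

#include <cstdio>

// hypothetical helper (not part of this series): maps the return codes of
// llama_decode, as documented in patch 7, onto a simple success/failure result
static bool try_decode(struct llama_context * ctx, struct llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    switch (ret) {
        case 0:
            // success - outputs (logits/embeddings) are valid
            return true;
        case 1:
            // warning: no KV slot was found for the batch;
            // the caller may reduce the batch size or increase the context and retry
            fprintf(stderr, "decode: could not find a KV slot for the batch\n");
            return false;
        case 2:
            // warning: the computation was aborted (GGML_STATUS_ABORTED);
            // the KV cache was rolled back (or cleared for recurrent models),
            // so the same batch may be resubmitted
            fprintf(stderr, "decode: computation aborted\n");
            return false;
        default:
            // < 0: allocation failure (-2) or backend failure (-3);
            // the KV cache state is restored to the state before the call
            fprintf(stderr, "decode: fatal error %d\n", ret);
            return false;
    }
}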