Skip to content

Commit 0b55716

Browse files
Xarbirus authored and ggerganov committed
llama : propagate the results of graph_compute (ggml-org#9525)
* llama: propagating the results of `graph_compute` to the user interface
* llama: reverting kv_cache in case of failed compute
* llama: `llama_kv_cache_state` was removed, only the result of `llama_graph_compute` is returned
* llama: restore a kv_cache in case of failed computation
* llama: correct reverting of the entire batch. also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models
* llama: updated comments
* llama : add comments about KV cache state after error

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 7eee12b commit 0b55716

File tree

2 files changed

+109
-15
lines changed

2 files changed

+109
-15
lines changed

include/llama.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,15 +797,15 @@ extern "C" {
797797
// Processes a batch of tokens with the encoder part of the encoder-decoder model.
798798
// Stores the encoder output internally for later use by the decoder cross-attention layers.
799799
// 0 - success
800-
// < 0 - error
800+
// < 0 - error. the KV cache state is restored to the state before this call
801801
LLAMA_API int32_t llama_encode(
802802
struct llama_context * ctx,
803803
struct llama_batch batch);
804804

805805
// Positive return values does not mean a fatal error, but rather a warning.
806806
// 0 - success
807807
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
808-
// < 0 - error
808+
// < 0 - error. the KV cache state is restored to the state before this call
809809
LLAMA_API int32_t llama_decode(
810810
struct llama_context * ctx,
811811
struct llama_batch batch);

src/llama.cpp

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3506,11 +3506,24 @@ static bool llama_kv_cache_init(
35063506
return true;
35073507
}
35083508

// Describes the outcome of llama_kv_cache_find_slot: whether a usable slot
// was found and, when it was, which cell range [begin, end) it occupies.
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // cell range of the slot: [begin, end)
    bool found = false;                       // true if a usable slot was located

    // success/failure marker without a concrete cell range
    explicit llama_kv_cache_slot_info(bool found_) : found(found_) {}
    // successful slot together with its cell range
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries(begin, end), found(true) {}

    // enables `if (slot)` / `if (!slot)` checks
    operator bool() const { return found; }
};
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
35093521
// find an empty slot of size "n_tokens" in the cache
35103522
// updates the cache head
3523+
// returns a structure holding information about the slot found
35113524
// Note: On success, it's important that cache.head points
35123525
// to the first cell of the slot.
3513-
static bool llama_kv_cache_find_slot(
3526+
static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
35143527
struct llama_kv_cache & cache,
35153528
const struct llama_ubatch & batch) {
35163529
const uint32_t n_tokens = batch.n_tokens;
@@ -3538,7 +3551,7 @@ static bool llama_kv_cache_find_slot(
35383551
// too big seq_id
35393552
// TODO: would it be possible to resize the cache instead?
35403553
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3541-
return false;
3554+
return llama_kv_cache_slot_info_failed;
35423555
}
35433556
if (j > 0) {
35443557
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3673,15 +3686,17 @@ static bool llama_kv_cache_find_slot(
36733686
// allow getting the range of used cells, from head to head + n
36743687
cache.head = min;
36753688
cache.n = max - min + 1;
3689+
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3690+
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
36763691

36773692
// sanity check
3678-
return cache.n >= n_seqs;
3693+
return llama_kv_cache_slot_info(cache.n >= n_seqs);
36793694
}
36803695
// otherwise, one cell per token.
36813696

36823697
if (n_tokens > cache.size) {
36833698
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3684-
return false;
3699+
return llama_kv_cache_slot_info_failed;
36853700
}
36863701

36873702
uint32_t n_tested = 0;
@@ -3709,7 +3724,7 @@ static bool llama_kv_cache_find_slot(
37093724

37103725
if (n_tested >= cache.size) {
37113726
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3712-
return false;
3727+
return llama_kv_cache_slot_info_failed;
37133728
}
37143729
}
37153730

@@ -3726,7 +3741,7 @@ static bool llama_kv_cache_find_slot(
37263741

37273742
cache.used += n_tokens;
37283743

3729-
return true;
3744+
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
37303745
}
37313746

37323747
// find how many cells are currently in use
@@ -4002,6 +4017,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
40024017
return cparams.flash_attn ? 256u : 32u;
40034018
}
40044019

4020+
// saves the kv_cache state for future recovery.
4021+
// used to rollback llama_kv_cache_find_slot changes.
4022+
struct llama_kv_slot_restorer {
4023+
struct llama_kv_cache_state {
4024+
uint32_t head = 0;
4025+
uint32_t n = 0;
4026+
} old_state;
4027+
4028+
// for non-recurrent models only
4029+
// list of slots to restore
4030+
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
4031+
4032+
bool do_restore = false;
4033+
4034+
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4035+
old_state.head = cache.head;
4036+
old_state.n = cache.n;
4037+
}
4038+
4039+
// saves a slot information for future restoration
4040+
void save(const struct llama_kv_cache_slot_info & slot) {
4041+
if (slot) {
4042+
do_restore = true;
4043+
if (slot.boundaries.first != slot.boundaries.second) {
4044+
slot_boundaries.push_back(slot.boundaries);
4045+
}
4046+
}
4047+
}
4048+
4049+
// must be explicitly called to restore the kv_cache state
4050+
// and rollback changes from all llama_kv_cache_find_slot calls
4051+
void restore(struct llama_kv_cache & cache) {
4052+
if (do_restore) {
4053+
cache.head = old_state.head;
4054+
cache.n = old_state.n;
4055+
4056+
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4057+
llama_kv_cache_seq_rm(cache, -1, -1, -1);
4058+
} else {
4059+
for (auto & slot : slot_boundaries) {
4060+
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4061+
}
4062+
}
4063+
}
4064+
}
4065+
};
4066+
40054067
//
40064068
// model loading and saving
40074069
//
@@ -17189,7 +17251,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
1718917251
}
1719017252
}
1719117253

17192-
static void llama_graph_compute(
17254+
// returns the result of ggml_backend_sched_graph_compute_async execution
17255+
static enum ggml_status llama_graph_compute(
1719317256
llama_context & lctx,
1719417257
ggml_cgraph * gf,
1719517258
int n_threads,
@@ -17204,15 +17267,20 @@ static void llama_graph_compute(
1720417267
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
1720517268
}
1720617269

17207-
auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17208-
if (err != GGML_STATUS_SUCCESS) {
17209-
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
17270+
auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17271+
if (status != GGML_STATUS_SUCCESS) {
17272+
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
1721017273
}
1721117274

1721217275
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
17276+
17277+
return status;
1721317278
}
1721417279

1721517280
// decode a batch of tokens by evaluating the transformer
17281+
// in case of unsuccessful decoding (error or warning),
17282+
// the kv_cache state will be returned to its original state
17283+
// (for non-recurrent models) or cleaned (for recurrent models)
1721617284
//
1721717285
// - lctx: llama context
1721817286
// - batch: batch to evaluate
@@ -17262,6 +17330,7 @@ static int llama_decode_internal(
1726217330
lctx.n_queued_tokens += n_tokens_all;
1726317331

1726417332
auto & kv_self = lctx.kv_self;
17333+
llama_kv_slot_restorer kv_slot_restorer(kv_self);
1726517334

1726617335
const int64_t n_embd = hparams.n_embd;
1726717336
const int64_t n_vocab = hparams.n_vocab;
@@ -17346,9 +17415,11 @@ static int llama_decode_internal(
1734617415
kv_self.head = 0;
1734717416
}
1734817417

17349-
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
17418+
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17419+
if (!slot) {
1735017420
return 1;
1735117421
}
17422+
kv_slot_restorer.save(slot);
1735217423

1735317424
if (!kv_self.recurrent) {
1735417425
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17395,7 +17466,19 @@ static int llama_decode_internal(
1739517466

1739617467
llama_set_inputs(lctx, ubatch);
1739717468

17398-
llama_graph_compute(lctx, gf, n_threads, threadpool);
17469+
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17470+
if (compute_status != GGML_STATUS_SUCCESS) {
17471+
kv_slot_restorer.restore(kv_self);
17472+
switch (compute_status) {
17473+
case GGML_STATUS_ABORTED:
17474+
return 2;
17475+
case GGML_STATUS_ALLOC_FAILED:
17476+
return -2;
17477+
case GGML_STATUS_FAILED:
17478+
default:
17479+
return -3;
17480+
}
17481+
}
1739917482

1740017483
// update the kv ring buffer
1740117484
{
@@ -17632,7 +17715,18 @@ static int llama_encode_internal(
1763217715

1763317716
llama_set_inputs(lctx, ubatch);
1763417717

17635-
llama_graph_compute(lctx, gf, n_threads, threadpool);
17718+
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17719+
switch (compute_status) {
17720+
case GGML_STATUS_SUCCESS:
17721+
break;
17722+
case GGML_STATUS_ABORTED:
17723+
return 2;
17724+
case GGML_STATUS_ALLOC_FAILED:
17725+
return -2;
17726+
case GGML_STATUS_FAILED:
17727+
default:
17728+
return -3;
17729+
}
1763617730

1763717731
// extract embeddings
1763817732
if (embd) {

0 commit comments

Comments (0)