@@ -3506,11 +3506,24 @@ static bool llama_kv_cache_init(
     return true;
 }
 
+// a structure holding information about the slot found in llama_kv_cache_find_slot
+struct llama_kv_cache_slot_info {
+    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
+    bool found = false;                       // the slot was found
+
+    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
+    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
+
+    operator bool() const { return found; }
+};
+static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
+
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// returns a structure holding information about the slot found
 // Note: On success, it's important that cache.head points
 // to the first cell of the slot.
-static bool llama_kv_cache_find_slot(
+static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         struct llama_kv_cache & cache,
         const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
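Note (not part of the diff): a minimal sketch of how a caller checks the new return value; `cache` and `ubatch` are assumed to be whatever the caller already has in scope.

    const auto slot = llama_kv_cache_find_slot(cache, ubatch);
    if (!slot) {
        // operator bool() returned false: no slot was found
    }
    // for the non-recurrent path the allocated cell range is reported as [begin, end)
    const uint32_t begin = slot.boundaries.first;
    const uint32_t end   = slot.boundaries.second;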
@@ -3538,7 +3551,7 @@ static bool llama_kv_cache_find_slot(
                     // too big seq_id
                     // TODO: would it be possible to resize the cache instead?
                     LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return false;
+                    return llama_kv_cache_slot_info_failed;
                 }
                 if (j > 0) {
                     llama_kv_cell & seq = cache.cells[seq_id];
@@ -3673,15 +3686,17 @@ static bool llama_kv_cache_find_slot(
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n = max - min + 1;
+        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
+            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return cache.n >= n_seqs;
+        return llama_kv_cache_slot_info(cache.n >= n_seqs);
     }
     // otherwise, one cell per token.
 
     if (n_tokens > cache.size) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return false;
+        return llama_kv_cache_slot_info_failed;
     }
 
     uint32_t n_tested = 0;
@@ -3709,7 +3724,7 @@ static bool llama_kv_cache_find_slot(
 
         if (n_tested >= cache.size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
+            return llama_kv_cache_slot_info_failed;
         }
     }
 
@@ -3726,7 +3741,7 @@ static bool llama_kv_cache_find_slot(
 
     cache.used += n_tokens;
 
-    return true;
+    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
 }
 
 // find how many cells are currently in use
@@ -4002,6 +4017,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+// saves the kv_cache state for future recovery.
+// used to rollback llama_kv_cache_find_slot changes.
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t n = 0;
+    } old_state;
+
+    // for non-recurrent models only
+    // list of slots to restore
+    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
+
+    bool do_restore = false;
+
+    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+        old_state.head = cache.head;
+        old_state.n = cache.n;
+    }
+
+    // saves slot information for future restoration
+    void save(const struct llama_kv_cache_slot_info & slot) {
+        if (slot) {
+            do_restore = true;
+            if (slot.boundaries.first != slot.boundaries.second) {
+                slot_boundaries.push_back(slot.boundaries);
+            }
+        }
+    }
+
+    // must be explicitly called to restore the kv_cache state
+    // and rollback changes from all llama_kv_cache_find_slot calls
+    void restore(struct llama_kv_cache & cache) {
+        if (do_restore) {
+            cache.head = old_state.head;
+            cache.n = old_state.n;
+
+            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
+                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+            } else {
+                for (auto & slot : slot_boundaries) {
+                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                }
+            }
+        }
+    }
+};
+
 //
 // model loading and saving
 //
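Note (not part of the diff): the intended save/restore pattern, mirroring what llama_decode_internal does further down in this diff; the variable names follow that function.

    llama_kv_slot_restorer kv_slot_restorer(kv_self);     // snapshot cache.head and cache.n

    const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
    if (!slot) {
        return 1;                                          // nothing was written, nothing to restore
    }
    kv_slot_restorer.save(slot);                           // remember the allocated cell range

    const auto status = llama_graph_compute(lctx, gf, n_threads, threadpool);
    if (status != GGML_STATUS_SUCCESS) {
        kv_slot_restorer.restore(kv_self);                 // roll the cache back to the snapshot
    }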
@@ -17189,7 +17251,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+// returns the result of ggml_backend_sched_graph_compute_async execution
+static enum ggml_status llama_graph_compute(
         llama_context & lctx,
         ggml_cgraph * gf,
         int n_threads,
@@ -17204,15 +17267,20 @@ static void llama_graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (err != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:  llama context
 //   - batch: batch to evaluate
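Note (not part of the diff): the same ggml_status-to-return-code mapping appears twice below, once in llama_decode_internal and once in llama_encode_internal. A hypothetical helper (not present in the patch) capturing that convention would look like:

    // hypothetical helper, not part of this patch
    static int llama_status_to_ret(enum ggml_status status) {
        switch (status) {
            case GGML_STATUS_SUCCESS:      return  0;
            case GGML_STATUS_ABORTED:      return  2; // warning: computation was aborted
            case GGML_STATUS_ALLOC_FAILED: return -2; // error: failed to allocate compute resources
            case GGML_STATUS_FAILED:
            default:                       return -3; // error: graph computation failed
        }
    }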
@@ -17262,6 +17330,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17346,9 +17415,11 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
                 return 1;
             }
+            kv_slot_restorer.save(slot);
 
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17395,7 +17466,19 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            kv_slot_restorer.restore(kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
 
         // update the kv ring buffer
         {
@@ -17632,7 +17715,18 @@ static int llama_encode_internal(
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // extract embeddings
     if (embd) {
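Note (not part of the diff): assuming the public llama_decode wrapper keeps forwarding llama_decode_internal's return value, caller code might handle the extended codes roughly like this:

    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // warning: no KV cache slot was found for the batch; the cache was not modified
    } else if (ret == 2) {
        // warning: computation was aborted; the KV cache has been restored (or cleaned for recurrent models)
    } else if (ret < 0) {
        // error (allocation or compute failure); the KV cache has likewise been restored/cleaned
    }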