@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
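Note on the hunk above: the explicit `return true;` makes the recurrent branch of `seq_rm()` exit before reaching the generic per-cell loop that follows. Previously control fell through, so the loop below also scanned the cells of recurrent caches after the branch had already handled them.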
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
         const llama_ubatch & ubatch) {
     const uint32_t n_tokens     = ubatch.n_tokens;
     const uint32_t n_seqs       = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
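The new `restore()`/`commit()` pair above works off a `pending` member that records which cell ranges the most recent `find_slot()` call claimed; its declaration lives in the header, which this diff does not show. A minimal sketch of what the usage implies — the type names here are illustrative, only `pending.ranges`, `range.c0`, and `range.c1` are fixed by the code above — would be:

```cpp
#include <cstdint>
#include <vector>

// Sketch only: the real declarations live in the corresponding header.
// The diff only pins down the names pending.ranges, range.c0, range.c1;
// the struct names are made up for illustration.
struct kv_range {
    uint32_t c0; // first cell claimed by find_slot() (inclusive)
    uint32_t c1; // one past the last claimed cell (exclusive)
};

struct kv_pending {
    std::vector<kv_range> ranges; // slots claimed since the last commit()
};
```

Separately, the new head-reset heuristic in `find_slot()` is sound because `head > used + 2*ubatch.n_tokens` implies that at least `2*n_tokens` of the cells before `head` are free (at most `used` cells are occupied anywhere), so restarting the search at 0 is likely to succeed quickly and keeps the cache compact.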
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             // too big seq_id
             // TODO: would it be possible to resize the cache instead?
             LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
         if (j > 0) {
             llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
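With `find_slot()` now returning `bool` and recording the claimed range in `pending`, a caller can treat a micro-batch as a transaction: claim a slot, run the computation, then either keep or roll back the cells. A hypothetical caller — `run_compute_graph()` is a made-up stand-in for the actual evaluation, only `find_slot()`, `commit()`, and `restore()` come from this diff — might look like:

```cpp
// Hypothetical usage sketch of the find_slot()/commit()/restore() guard.
bool decode_ubatch(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
    if (!kv.find_slot(ubatch)) {
        return false; // no cells were claimed, nothing to roll back
    }

    if (!run_compute_graph(ubatch)) { // made-up helper, not part of the diff
        kv.restore();                 // wipe the cells find_slot() just claimed
        return false;
    }

    kv.commit(); // accept the pending ranges as part of the cache state
    return true;
}
```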
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
+    commit();
 
     // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
     // Assume that this is one contiguous block of cells
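The `commit()` added to `state_read_meta()` presumably accepts the ranges that `find_slot()` just recorded for the deserialized cells right away: the loaded state is authoritative, so leaving the ranges pending would allow a later `restore()` to wipe cells that were just read from the session file.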