
Commit a10b36c

llama : refactor kv cache guard (#12695)
* llama : refactor kv cache guard  ggml-ci
* cont : fix comment  [no ci]
* llama : fix kv_cache restore logic  ggml-ci
* context : simplify kv cache updates  ggml-ci
* cont : better name  [no ci]
* llama : fix llama_decode return code when could not find KV slot  ggml-ci
* context : change log err -> warn  [no ci]
* kv-cache : add comment + warning
1 parent 83a88bd commit a10b36c

4 files changed: +107 additions, -127 deletions
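The commit replaces the ad-hoc batch_guard defined inside llama_context::decode() with a reusable llama_kv_cache_guard and moves the slot bookkeeping (pending ranges, restore/commit) into llama_kv_cache_unified itself. The guard type is declared in src/llama-kv-cache.h, one of the four changed files but not shown in this view; based on how it is used in src/llama-context.cpp below (construction from kv_self.get(), automatic rollback on scope exit, explicit commit() on success), a minimal sketch of it could look like this (hypothetical, inferred from usage):

// Hypothetical sketch only - the real declaration lives in src/llama-kv-cache.h
// (not shown in this view) and may differ in detail.
struct llama_kv_cache_guard {
    explicit llama_kv_cache_guard(llama_kv_cache_unified * kv) : kv(kv) {}

    // if decode() returns early without committing, roll back the pending
    // slot allocations so the cache is left in its previous state
    ~llama_kv_cache_guard() {
        kv->restore();
    }

    // called once the whole batch has been processed successfully
    void commit() {
        kv->commit();
    }

private:
    llama_kv_cache_unified * kv;
};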

examples/parallel/parallel.cpp

Lines changed: 2 additions & 0 deletions

@@ -106,6 +106,8 @@ int main(int argc, char ** argv) {
 
     common_params params;
 
+    params.n_predict = 128;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
     }

src/llama-context.cpp

Lines changed: 8 additions & 51 deletions

@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;
 
-    // TODO: remove this stuff
-    class batch_guard {
-    public:
-        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
-        }
-
-        ~batch_guard() {
-            if (!is_done) {
-                kv_slot_restorer.restore();
-            }
-        }
-
-        void done() {
-            is_done = true;
-        }
-
-        void save(const llama_kv_cache_slot_info & slot_info) {
-            kv_slot_restorer.save(slot_info);
-        }
-
-    private:
-        bool is_done = false;
-
-        llama_kv_slot_restorer kv_slot_restorer;
-    };
-
-    batch_guard bg(*kv_self);
+    llama_kv_cache_guard kv_guard(kv_self.get());
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // handle any pending defrags/shifts
+    kv_self_update();
+
     int64_t n_outputs_prev = 0;
 
     while (sbatch.n_tokens > 0) {
@@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         {
-            kv_self_update();
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
-                kv_self->head = 0;
+                return 1;
             }
 
-            const auto slot_info = kv_self->find_slot(ubatch);
-            if (!slot_info) {
-                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-                return -3;
-            }
-
-            bg.save(slot_info);
-
             if (!kv_self->recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
@@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self->head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self->head >= kv_self->size) {
-                kv_self->head = 0;
-            }
-        }
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // finalize the batch processing
-    bg.done();
+    kv_guard.commit();
 
     // set output mappings
     {
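Taken together, decode() now follows a transactional pattern: find_slot() records the cells it hands out as pending, an early return (including the new `return 1` when no KV slot is available) lets the guard's destructor restore them, and kv_guard.commit() at the end makes them permanent. The following standalone program (illustrative only, not llama.cpp code) shows the same commit/rollback idiom in isolation:

#include <cstdio>
#include <vector>

// work is first recorded as "pending" and only becomes permanent once
// commit() is reached; any early return lets the guard roll it back
struct pending_store {
    std::vector<int> committed;
    std::vector<int> pending;

    // drop everything that was not committed
    void restore() {
        pending.clear();
    }

    // make the pending work permanent
    void commit() {
        committed.insert(committed.end(), pending.begin(), pending.end());
        pending.clear();
    }
};

struct store_guard {
    explicit store_guard(pending_store & s) : s(s) {}

    // runs on every exit path; a no-op if commit() already cleared pending
    ~store_guard() {
        s.restore();
    }

    void commit() {
        s.commit();
    }

private:
    pending_store & s;
};

// mimics the shape of llama_context::decode(): allocate, maybe fail, commit on success
static bool process(pending_store & s, bool fail) {
    store_guard guard(s);

    s.pending.push_back(42); // analogous to find_slot() recording a range

    if (fail) {
        return false;        // early return -> ~store_guard() restores
    }

    guard.commit();          // success -> keep the work
    return true;
}

int main() {
    pending_store s;

    process(s, /*fail=*/true);
    std::printf("after failed run    : committed = %zu\n", s.committed.size()); // 0

    process(s, /*fail=*/false);
    std::printf("after successful run: committed = %zu\n", s.committed.size()); // 1

    return 0;
}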

src/llama-kv-cache.cpp

Lines changed: 61 additions & 8 deletions

@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
            const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
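restore() and commit() above, together with the pending.ranges.push_back() further down in find_slot(), rely on a small bookkeeping structure that is declared in src/llama-kv-cache.h and therefore not visible in this diff. A plausible sketch of it, inferred from the field names used here:

// Hypothetical sketch - the real declaration lives in src/llama-kv-cache.h
// (not shown in this view). The names pending.ranges, c0 and c1 are taken
// from the usage in restore() and find_slot() in this file.
struct slot_range {
    uint32_t c0; // first cell handed out by find_slot()
    uint32_t c1; // one past the last cell
};

struct {
    // ranges allocated since the last commit(); restore() clears these
    // cells again, commit() simply forgets them
    std::vector<slot_range> ranges;
} pending;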
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                return llama_kv_cache_slot_info_failed;
+                return false;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
+    commit();
 
     // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
     // Assume that this is one contiguous block of cells
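The commit() added at the end of this hunk reflects the new contract of find_slot(): every successful call records a pending range, so code paths that do not go through llama_kv_cache_guard (such as state_read_meta() here) have to commit the allocation themselves, presumably so a later restore() cannot wipe the cells that were just filled from the session data. A minimal sketch of that pairing (illustrative, not literal llama.cpp code):

// illustrative only: pair every successful find_slot() with commit()
// when no llama_kv_cache_guard is in scope
if (!find_slot(ubatch)) {
    return false; // failure records nothing, so there is nothing to undo
}
commit();         // keep the freshly allocated cells out of reach of restore()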
