Commit 12d0188
kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci
* kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci
* kv-cache : some comments ggml-ci
* context : fix graph reserve for multiple sequences ggml-ci
* kv-cache : fix typo [no ci]
* kv-cache : fix find_slot() logic for free slots ggml-ci
* llama : add TODO for deprecating the defrag API in the future
* kv-cache : improve find_slot() using min/max seq pos info ggml-ci
* llama : handle aborts and compute errors ggml-ci
* memory : extract state into llama_memory_state ggml-ci
* kv-cache : add comments ggml-ci
* server : update batching logic to reset n_batch on successful decode
* server : upon full re-processing, remove the sequence from the cache
* kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci
1 parent eb39499 commit 12d0188

14 files changed: +1296 -647 lines changed

examples/parallel/parallel.cpp

Lines changed: 10 additions & 3 deletions
```diff
@@ -362,15 +362,17 @@ int main(int argc, char ** argv) {
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
 
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        int32_t i_next = 0;
+
+        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
             // experiment: process in powers of 2
             //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
             //    n_batch /= 2;
             //    i -= n_batch;
             //    continue;
             //}
 
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
                 n_tokens,
@@ -396,13 +398,18 @@
 
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
-                i -= n_batch;
 
                 continue;
             }
 
             LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
 
+            // move the head of the batch forward with the number of tokens we just processed
+            i_next = i + n_tokens;
+
+            // on successful decode, restore the original batch size
+            n_batch = params.n_batch;
+
             for (auto & client : clients) {
                 if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                     continue;
```
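
The key change is the loop-advance logic: on a failed decode, `i` no longer moves backward, the batch is simply halved and retried at the same position; on success, `i_next` jumps past exactly the tokens that were processed and `n_batch` is restored to its original value. A minimal standalone sketch of this retry pattern (the `try_decode` stub and token counts are illustrative stand-ins, not llama.cpp API):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// hypothetical stand-in for llama_decode(): pretend the KV cache only has
// room for batches of up to 8 tokens
static bool try_decode(int32_t i, int32_t n_tokens) {
    if (n_tokens > 8) {
        return false; // no free slot found
    }
    std::printf("decoded tokens [%d, %d)\n", i, i + n_tokens);
    return true;
}

int main() {
    const int32_t n_tokens_total = 40; // toy batch
    const int32_t n_batch_max    = 32;

    int32_t n_batch = n_batch_max;
    int32_t i_next  = 0;

    for (int32_t i = 0; i < n_tokens_total; i = i_next) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

        if (!try_decode(i, n_tokens)) {
            // failure: halve the batch and retry at the same position i
            n_batch /= 2;
            continue;
        }

        // success: advance past the tokens just processed and restore
        // the full batch size for the next chunk
        i_next  = i + n_tokens;
        n_batch = n_batch_max;
    }
    return 0;
}
```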

include/llama.h

Lines changed: 7 additions & 3 deletions
```diff
@@ -259,9 +259,9 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
+        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
@@ -692,12 +694,14 @@ extern "C" {
     // This will be applied:
     //   - lazily on next llama_decode()
     //   - explicitly with llama_kv_self_update()
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
 
     //
```
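
The new comments tighten the contract of these two getters: for a given sequence, every position in `[pos_min, pos_max]` is present in the KV cache, and `pos_min` is typically non-zero only for SWA caches. A hedged sketch of how a caller might lean on that contiguity guarantee (the helper and its reuse policy are illustrative, not part of the API):

```cpp
#include <algorithm>

#include "llama.h"

// How many of the first n_past positions of `seq_id` can be reused from the
// KV cache? Illustrative helper: it relies on the documented guarantee that
// every position in [pos_min, pos_max] is present.
static llama_pos reusable_prefix(struct llama_context * ctx, llama_seq_id seq_id, llama_pos n_past) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p_min != 0) {
        // p_min == -1: the sequence is empty; p_min > 0: the prefix starting
        // at position 0 has been evicted (e.g. by a SWA cache), so nothing
        // from the beginning of the prompt can be reused
        return 0;
    }

    // positions [0, p_max] are guaranteed to be contiguous in the cache
    return std::min(p_max + 1, n_past);
}
```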

src/llama-batch.cpp

Lines changed: 19 additions & 12 deletions
```diff
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-    ubatch_token.resize(!has_embd ? n_ubatch : 0);
-    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);
+
     llama_ubatch ubatch = {
         /*equal_seqs   =*/ true,
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
-        /*pos          =*/ ubatch_pos.data(),
-        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-        /*seq_id       =*/ ubatch_seq_id.data(),
-        /*output       =*/ ubatch_output.data(),
+        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
+        /*pos          =*/ udata.pos.data(),
+        /*n_seq_id     =*/ udata.n_seq_id.data(),
+        /*seq_id       =*/ udata.seq_id.data(),
+        /*output       =*/ udata.output.data(),
     };
+
     return ubatch;
 }
 
```
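
The rewritten `reserve_ubatch()` appends a fresh `ubatch_data` and points the returned `llama_ubatch` views into it, instead of reusing one shared set of buffers. A condensed sketch of that ownership pattern with simplified stand-in types (none of these names are the real structs):

```cpp
#include <cstddef>
#include <vector>

// simplified stand-ins for the real llama types
struct udata_t  { std::vector<int> token, pos; };
struct ubatch_t { int * token; int * pos; };

struct sbatch_t {
    std::vector<udata_t> udatas; // one backing store per reserved ubatch

    ubatch_t reserve(size_t n) {
        udatas.push_back({});      // fresh storage owned by this sbatch
        auto & d = udatas.back();

        d.token.resize(n);
        d.pos.resize(n);

        // the returned pointers stay valid across later reserve() calls:
        // growing the outer vector moves the inner vectors, which keeps
        // their heap buffers (and thus these pointers) intact
        return { d.token.data(), d.pos.data() };
    }
};
```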

src/llama-batch.h

Lines changed: 15 additions & 10 deletions
```diff
@@ -11,15 +11,15 @@ struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t       *  output;   // [n_tokens]
 };
 
@@ -49,13 +49,18 @@ struct llama_sbatch {
 
     const llama_batch * batch = nullptr;
 
-    // buffers for the ubatch
-    std::vector<llama_token>    ubatch_token;
-    std::vector<float>          ubatch_embd;
-    std::vector<llama_pos>      ubatch_pos;
-    std::vector<int32_t>        ubatch_n_seq_id;
-    std::vector<llama_seq_id *> ubatch_seq_id;
-    std::vector<int8_t>         ubatch_output;
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
+
+    std::vector<ubatch_data> udatas;
 
     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
```
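
A plausible reading of why the buffers move behind per-ubatch `ubatch_data` entries (the TODO flags it as a stopgap) is pointer stability: with a single shared buffer, reserving a second ubatch could resize the vectors and silently invalidate the views handed out for the first one. A toy illustration of the hazard, using deliberately simplified types unrelated to llama.cpp's:

```cpp
#include <cstdio>
#include <vector>

struct view_t { const int * data; };

int main() {
    // old scheme: views point into one shared, reusable buffer
    std::vector<int> shared(4, 1);
    view_t a = { shared.data() };
    shared.resize(1024);  // a later reservation may reallocate...
    (void)a;              // ...so a.data may now dangle: UB if dereferenced

    // new scheme: each reservation owns a stable backing store
    std::vector<std::vector<int>> udatas;
    udatas.push_back(std::vector<int>(4, 1));
    view_t b = { udatas.back().data() };
    udatas.push_back(std::vector<int>(8, 2)); // moves the inner vectors only;
                                              // b's heap buffer is untouched
    std::printf("%d\n", *b.data); // safe: prints 1
    return 0;
}
```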
