Commit 04d8fb4

kv-cache : rework interface (wip) [no ci]
1 parent dd394a6 commit 04d8fb4

File tree: 9 files changed (+613, -416 lines)


include/llama.h

Lines changed: 3 additions & 3 deletions
@@ -259,9 +259,9 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits;   // TODO: rename this to "output"
+        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
+        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;

     enum llama_model_kv_override_type {
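
For reference, the public llama_batch above still lets a single batch item belong to several sequences through n_seq_id/seq_id; the new TODOs point toward each item carrying exactly one sequence id. A minimal sketch of how a batch is filled against the current layout (the some_tokens array and its values are placeholders, not from this commit):

    // fill a batch the way current callers do, with exactly one sequence per token
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 4, /*embd =*/ 0, /*n_seq_max =*/ 1);

    for (int i = 0; i < 4; ++i) {
        batch.token[i]     = some_tokens[i]; // placeholder token ids
        batch.pos[i]       = i;              // position within the sequence
        batch.n_seq_id[i]  = 1;              // one sequence per item - what the TODO would make implicit
        batch.seq_id[i][0] = 0;              // sequence 0
        batch.logits[i]    = (i == 3);       // request output only for the last token
    }
    batch.n_tokens = 4;

    // ... llama_decode(ctx, batch); ...
    llama_batch_free(batch);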

src/llama-batch.cpp

Lines changed: 14 additions & 0 deletions
@@ -4,6 +4,15 @@
 #include <cstring>
 #include <algorithm>

+void llama_ubatch::update() {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        const llama_seq_id s = seq_id[i][0];
+
+        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
+        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
+    }
+}
+
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
     // the previous ubatch is assumed to be gone,

@@ -26,6 +35,8 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
+        /*seq_pos_min  =*/ {-1},
+        /*seq_pos_max  =*/ {-1},
         /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
         /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
         /*pos          =*/ ubatch_pos.data(),

@@ -148,6 +159,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
         GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
         add_seq_to_ubatch(ubatch, s, length);
     }
+    ubatch.update();
     return ubatch;
 }

@@ -175,6 +187,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
             if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
+    ubatch.update();
     return ubatch;
 }

@@ -187,6 +200,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
         GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
         add_seq_to_ubatch(ubatch, s, length);
     }
+    ubatch.update();
     return ubatch;
 }
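
The update() pass added above records, per sequence, the smallest and largest position present in the ubatch, with -1 standing for "no token from this sequence yet". A self-contained sketch of the same bookkeeping on toy data (the arrays and sizes here are illustrative only):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // toy ubatch: 4 tokens, each tagged with one sequence id and a position
        const int32_t seq[4] = { 0, 0, 1, 1 };
        const int32_t pos[4] = { 5, 6, 0, 1 };

        int32_t seq_pos_min[2] = { -1, -1 }; // -1 = "unset"
        int32_t seq_pos_max[2] = { -1, -1 };

        for (int i = 0; i < 4; ++i) {
            const int32_t s = seq[i];
            seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
            seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
        }

        // prints: seq 0 -> [5, 6], seq 1 -> [0, 1]
        for (int s = 0; s < 2; ++s) {
            printf("seq %d -> [%d, %d]\n", s, seq_pos_min[s], seq_pos_max[s]);
        }
        return 0;
    }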

src/llama-batch.h

Lines changed: 9 additions & 3 deletions
@@ -1,25 +1,31 @@
 #pragma once

 #include "llama.h"
+#include "llama-cparams.h"

 #include <array>
 #include <vector>

 // very similar to llama_batch,
 // but has more metadata about sequences
 struct llama_ubatch {
+    void update();
+
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?

-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;

+    llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence
+    llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence
+
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t       *  output;   // [n_tokens]
 };
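
Since seq_pos_min/seq_pos_max are plain array members (sized by LLAMA_MAX_PARALLEL_SEQUENCES, which is why llama-cparams.h is now included), every aggregate initializer of llama_ubatch has to supply values for them; the call sites in llama-context.cpp below pass {-1}, {-1}. A hedged sketch of such an initializer, with the non-new fields reduced to placeholders:

    // sketch only: mirrors the reserve-graph call sites in llama-context.cpp
    llama_token token = 0; // placeholder token id

    llama_ubatch ubatch = {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 1,
        /*n_seq_tokens =*/ 1,
        /*n_seqs       =*/ 1,
        /*seq_pos_min  =*/ {-1}, // note: brace-init sets element 0 to -1 and zero-initializes the rest
        /*seq_pos_max  =*/ {-1}, // "unset" sentinel; llama_ubatch::update() derives the real ranges for split ubatches
        /*token        =*/ &token,
        /*embd         =*/ nullptr,
        /*pos          =*/ nullptr,
        /*n_seq_id     =*/ nullptr,
        /*seq_id       =*/ nullptr,
        /*output       =*/ nullptr,
    };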

src/llama-context.cpp

Lines changed: 25 additions & 36 deletions
@@ -285,7 +285,7 @@ llama_context::llama_context(

     // reserve pp graph first so that buffers are only allocated once
     {
-        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

         // max number of outputs
         n_outputs = ubatch_pp.n_tokens;

@@ -305,7 +305,7 @@ llama_context::llama_context(

     // reserve with tg graph to get the number of splits and nodes
     {
-        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

         n_outputs = ubatch_tg.n_tokens;


@@ -324,7 +324,7 @@ llama_context::llama_context(

     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
-        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

         n_outputs = ubatch_pp.n_tokens;


@@ -472,7 +472,7 @@ void llama_context::kv_self_update() {
     kv_self->set_full();

     llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

     auto * gf = graph_init();
     graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

@@ -731,8 +731,6 @@ int llama_context::encode(llama_batch & inp_batch) {

     n_outputs = n_tokens;

-    //batch_manager->prepare(ubatch);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);


@@ -883,8 +881,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

     if (batch.token) {

@@ -924,21 +920,24 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+    // handle any pending defrags/shifts
+    kv_self_update();
+
+    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!decode_state) {
+        return 1;
+    }

     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };

-    // handle any pending defrags/shifts
-    kv_self_update();
-
     int64_t n_outputs_prev = 0;

-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    while (const auto * ubatch_ptr = decode_state->next()) {
+        const auto & ubatch = *ubatch_ptr;

         // count the outputs in this u_batch
         {

@@ -957,11 +956,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }

-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);


@@ -1072,17 +1066,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_prev += n_outputs;
     }

-    // finalize the batch processing
-    kv_guard.commit();
-
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;

     // set output mappings
     {
         bool sorted_output = true;

-        auto & out_ids = sbatch.out_ids;
+        auto & out_ids = decode_state->out_ids();

         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
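
With this change, decode() obtains a batch-processing state from kv_self->init(...) and pulls ubatches from it via next(), instead of building an sbatch and calling find_slot() per ubatch. The declaration of that state object is in the kv-cache headers, which are among the changed files not shown in this excerpt; inferred purely from the call sites above, its shape is roughly the following (names, element types and signatures are assumptions, not the actual interface):

    #include <cstddef>
    #include <vector>

    struct llama_ubatch; // from llama-batch.h

    // hypothetical sketch of the object returned by kv_self->init()
    struct llama_kv_decode_state {
        virtual ~llama_kv_decode_state() = default;

        // next ubatch to process, or nullptr once the batch is exhausted;
        // the cache is expected to have reserved a slot for it (no explicit find_slot() call anymore)
        virtual const llama_ubatch * next() = 0;

        // mapping of outputs back to their order in the original batch
        virtual std::vector<std::size_t> & out_ids() = 0;
    };

    // usage, mirroring llama_context::decode():
    //
    //   auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, logits_all);
    //   if (!decode_state) { return 1; }
    //   while (const auto * ubatch = decode_state->next()) {
    //       // build and compute the graph for *ubatch
    //   }
    //   auto & out_ids = decode_state->out_ids();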

@@ -1939,7 +1930,6 @@ void llama_context::opt_epoch_iter(
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

     kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);

     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;

@@ -1962,25 +1952,26 @@ void llama_context::opt_epoch_iter(

         int64_t n_outputs_all = n_tokens_all;

-        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        //llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!decode_state) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }

         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };

-        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
-            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+        uint32_t pos_batch = 0;
+        while (const auto * ubatch_ptr = decode_state->next()) {
+            const auto & ubatch = *ubatch_ptr;

-            n_outputs = ubatch.n_tokens;
+            pos_batch += ubatch.n_tokens;

-            // TODO: not sure if this is needed
-            if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
-            }
+            n_outputs = ubatch.n_tokens;

             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

@@ -2017,8 +2008,6 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
         }
     }
-
-    kv_guard.commit();
 }

 void llama_context::opt_epoch(
