
Commit 825efad

kv-cache : extract the graph-specific state from the KV object

ggml-ci

1 parent a592c13

7 files changed: +385, -167 lines
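At a high level, the commit stops mutating the KV cache to simulate a "full" state and instead has the cache hand out an immutable state object that graph construction consumes. A minimal, self-contained sketch of that pattern, with simplified stand-ins (memory_state_i, kv_cache, and graph_reserve below are illustrative types for this sketch, not the actual llama.cpp definitions):

#include <cstdint>
#include <memory>
#include <stdexcept>

// Illustrative stand-in for llama_memory_state_i: a read-only snapshot of the
// cache layout that graph construction can consume without touching the cache.
struct memory_state_i {
    virtual ~memory_state_i() = default;
    virtual uint32_t n_kv() const = 0; // number of cells the graph must cover
};

// Illustrative stand-in for llama_kv_cache.
struct kv_cache {
    uint32_t size = 4096;

    struct full_state : memory_state_i {
        uint32_t n;
        explicit full_state(uint32_t n) : n(n) {}
        uint32_t n_kv() const override { return n; }
    };

    // Replaces the old set_full() mutation: instead of flipping the cache into
    // a fake "full" mode, return a state object that describes a full cache.
    std::unique_ptr<memory_state_i> init_full() const {
        return std::make_unique<full_state>(size);
    }
};

// Graph reservation now depends only on the state, not on the mutable cache.
void graph_reserve(const memory_state_i * mstate) {
    if (!mstate) throw std::runtime_error("no memory state");
    // ... size worst-case buffers from mstate->n_kv() ...
}

Separating the "what the graph must account for" description from the cache object itself is what lets the reservation code below pass a synthetic full state without disturbing the live cache.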

src/llama-context.cpp

22 additions & 15 deletions
@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
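All three constructor hunks apply the same substitution: set_full() used to mutate the cache in place so the subsequent reservations would see a worst-case layout, whereas init_full() returns a state handle that lives across all three graph_reserve() calls and is passed to each one explicitly. Condensed, the new flow reads as follows (names as in the commit; surrounding setup elided):

// kv_state keeps the worst-case layout alive for the whole reservation pass
const auto kv_state = kv_self->init_full();
if (!kv_state) {
    throw std::runtime_error("failed to initialize KV cache");
}

// pp graph first so that buffers are only allocated once
graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
// tg graph to measure the number of splits and nodes
graph_reserve(1, 1, 1, kv_state.get());
// pp graph again to avoid ggml-alloc reallocations during inference
graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());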
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
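The update path follows the same recipe, re-reserving the worst-case graph only when update() actually did work. The worst-case ubatch is sized from the context parameters; a tiny standalone illustration of that sizing rule (the numbers are made up):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical context parameters, mirroring cparams in the hunk above.
    const uint32_t n_ctx     = 8192;
    const uint32_t n_ubatch  = 512;
    const uint32_t n_seq_max = 4;

    // Worst case: all sequences active, ubatch as large as it can ever get.
    const uint32_t n_seqs   = n_seq_max;
    const uint32_t n_tokens = std::min(n_ctx, n_ubatch);

    std::printf("reserve ubatch: n_tokens=%u n_seqs=%u\n",
                (unsigned) n_tokens, (unsigned) n_seqs);
}

Note that on this path a failed graph_reserve() only logs an error rather than throwing, since the context is already live.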
@@ -651,7 +657,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         if (ret) {
@@ -1269,7 +1275,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1289,7 +1295,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1310,10 +1316,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx          =*/ ctx,
@@ -1325,7 +1332,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu  =*/ backend_cpu,
             /*.cvec         =*/ &cvec,
             /*.loras        =*/ &loras,
-            /*.memory       =*/ memory.get(),
+            /*.mstate       =*/ mstate,
             /*.cross        =*/ &cross,
             /*.n_outputs    =*/ n_outputs,
             /*.cb           =*/ graph_get_cb(),
@@ -2047,7 +2054,7 @@ void llama_context::opt_epoch_iter(
         n_outputs = ubatch.n_tokens;
 
         auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
         struct ggml_context * ctx_compute_opt;
         {
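Taken together, the remaining hunks thread mstate through every graph entry point (process_ubatch, graph_reserve, graph_build, opt_epoch_iter) and swap the builder's captured /*.memory =*/ memory.get() for the injected /*.mstate =*/ mstate. The payoff is that one builder can be driven by different states, e.g. a synthetic full state for reservation and a real per-ubatch state for decoding, without the cache impersonating either. A sketch of the dependency change, reusing the illustrative memory_state_i from the first sketch above (not the actual llama.cpp structs):

// Before: the graph params captured the mutable memory object itself.
// After: they carry an immutable, per-call state.
struct graph_params {
    const memory_state_i * mstate; // was: const llama_memory_i * memory
};

// One builder, many states: reservation and decoding share this code path,
// differing only in which state they pass in.
void build_graph(const graph_params & params) {
    // ... read params.mstate; the cache itself is never touched here ...
}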

src/llama-context.h

7 additions & 5 deletions
@@ -18,6 +18,7 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_i;
 class llama_memory_state_i;
 
 struct llama_context {
@@ -196,14 +197,15 @@ struct llama_context {
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype);
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
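On the header side, a forward declaration of llama_memory_i is enough (llama_memory_state_i was already forward-declared) because the new parameters are pointers: C++ permits declaring functions that take pointers or references to an incomplete type, so the full class definitions stay out of llama-context.h. A standalone illustration of the idiom (the _demo names are hypothetical):

// Forward declaration: the type is incomplete here, but that is enough to
// declare functions that take a pointer or reference to it.
class memory_state_demo;

struct context_demo {
    // OK even though memory_state_demo is incomplete at this point.
    void reserve(const memory_state_demo * mstate);
};

// The full definition is only required where the pointer is dereferenced.
class memory_state_demo {
public:
    int n_kv = 0;
};

void context_demo::reserve(const memory_state_demo * mstate) {
    (void) mstate->n_kv; // fine: the type is complete by now
}

int main() {
    memory_state_demo st;
    context_demo ctx;
    ctx.reserve(&st);
}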
