@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
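The shift in this hunk is from a mutating call (set_full()) to a factory that returns an owning state object, which the caller validates once and then threads into every graph routine. A minimal compilable sketch of that pattern follows; the types are illustrative stand-ins, not the real llama.cpp declarations.

#include <cstdint>
#include <cstdio>
#include <memory>
#include <stdexcept>

struct memory_state_i {                  // stand-in for llama_memory_state_i
    virtual ~memory_state_i() = default;
    virtual uint32_t n_kv() const = 0;   // worst-case number of cells the graph must cover
};

struct kv_cache {                        // stand-in for llama_kv_cache
    uint32_t size;

    std::unique_ptr<memory_state_i> init_full() const {
        struct full_state : memory_state_i {
            uint32_t n;
            uint32_t n_kv() const override { return n; }
        };
        auto st = std::make_unique<full_state>();
        st->n = size;                    // pretend every cell is occupied
        return st;
    }
};

static void graph_reserve(uint32_t n_tokens, const memory_state_i * mstate) {
    // graph sizing reads the state object, never the live cache
    std::printf("reserve: n_tokens = %u, n_kv = %u\n", n_tokens, mstate->n_kv());
}

int main() {
    kv_cache kv{4096};

    const auto kv_state = kv.init_full();
    if (!kv_state) {
        throw std::runtime_error("failed to initialize KV cache");
    }

    graph_reserve(512, kv_state.get());  // pp (prompt processing) worst case
    graph_reserve(1,   kv_state.get());  // tg (token generation) worst case
}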
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
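The same pattern repeats in kv_self_update(): when update() reports that the cache actually did work (e.g. a defrag moved cells), the previously reserved worst-case graph may no longer fit, so a fresh full-cache state is created and the graph is reserved again. A sketch of that control flow, again with hypothetical stand-in types:

#include <memory>
#include <stdexcept>

struct state { unsigned n_kv; };

struct cache {
    // returns true when the cache actually moved data (e.g. a defrag ran)
    bool update() { return true; }
    std::unique_ptr<state> init_full() { return std::make_unique<state>(state{4096}); }
};

static void reserve_worst_case(const state * st) {
    (void) st; // real code rebuilds the compute graph for st->n_kv cells
}

int main() {
    cache kv;
    if (kv.update()) {
        // cache layout changed -> the old reserved graph may be too small
        const auto kv_state = kv.init_full();
        if (!kv_state) {
            throw std::runtime_error("failed to initialize KV cache");
        }
        reserve_worst_case(kv_state.get());
    }
}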
@@ -651,7 +657,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         if (ret) {
@@ -1269,7 +1275,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1289,7 +1295,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1310,10 +1316,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-          ggml_context * ctx,
-           ggml_cgraph * gf,
-    const llama_ubatch & ubatch,
-        llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
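With graph_build() taking the state as an explicit parameter, graph construction no longer reaches into the context's own memory member; each caller decides which state a graph is built against, and graph_reserve() simply forwards the one it was given. A compilable sketch of that threading, using assumed, simplified signatures:

#include <cstdint>
#include <cstdio>

struct memory_state_i { virtual ~memory_state_i() = default; };

struct context {
    // reserve builds a worst-case dummy batch, then defers to graph_build()
    void graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs,
            const memory_state_i * mstate) {
        std::printf("reserving: n_tokens = %u, n_seqs = %u, n_outputs = %u\n",
                n_tokens, n_seqs, n_outputs);
        graph_build(mstate);
    }

    // the graph reads cache metadata exclusively through mstate
    void graph_build(const memory_state_i * mstate) {
        (void) mstate;
    }
};

int main() {
    memory_state_i st;   // e.g. the result of init_full() in the hunks above
    context ctx;
    ctx.graph_reserve(512, 1, 512, &st);
}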
@@ -1325,7 +1332,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.memory      =*/ memory.get(),
+            /*.mstate      =*/ mstate,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
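The /*.field =*/ comments in this initializer are llama.cpp's convention for labeling positional aggregate initialization (the project builds as C++17, and designated initializers only arrived in C++20). A self-contained illustration of the idiom with a hypothetical params struct:

#include <cstdio>

struct graph_params {        // hypothetical, loosely modeled on the struct above
    const void * mstate;     // memory state the graph will read from
    int          n_outputs;
};

int main() {
    int dummy_state = 0;

    // field-name comments keep positional initialization readable and make
    // reordering mistakes easy to spot in review
    graph_params params = {
        /*.mstate    =*/ &dummy_state,
        /*.n_outputs =*/ 32,
    };

    std::printf("n_outputs = %d\n", params.n_outputs);
}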
@@ -2047,7 +2054,7 @@ void llama_context::opt_epoch_iter(
         n_outputs = ubatch.n_tokens;
 
         auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
         struct ggml_context * ctx_compute_opt;
         {