@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
@@ -633,32 +639,32 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_ALLOC_FAILED;
-        }
+        ret = GGML_STATUS_ALLOC_FAILED;
         return nullptr;
     }
 
@@ -667,12 +673,12 @@ llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_gra
     const auto status = graph_compute(gf, ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
-        if (ret) {
-            *ret = status;
-        }
+        ret = status;
         return nullptr;
     }
 
+    ret = GGML_STATUS_SUCCESS;
+
     return res;
 }
 
@@ -748,7 +754,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +933,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
    kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +961,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
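Read together, the hunks in this patch exercise a small surface on the state object handed out by the KV cache. The following is only an illustrative sketch that collects those calls in one place; it is not the actual `llama_memory_state_i` declaration from the llama.cpp headers, and everything beyond the member-function names used in the hunks is an assumption.

```cpp
// Illustrative interface sketch, collecting the operations this diff calls on
// the object returned by kv_self->init_full() / kv_self->init_batch().
// The method names come from the hunks; signatures and types are assumptions.
struct llama_memory_state_sketch {
    virtual ~llama_memory_state_sketch() = default;

    // checked right after init_batch() in decode() and opt_epoch_iter()
    virtual llama_memory_status get_status() const = 0;

    // called by process_ubatch() before building/computing the graph
    virtual bool apply() = 0;

    // the ubatch to process in the current iteration of the do { ... } loops
    virtual const llama_ubatch & get_ubatch() const = 0;

    // advance to the next ubatch; returns false when the batch is exhausted
    virtual bool next() = 0;

    // decode() also calls out_ids(); its return type is not visible in this
    // diff, so it is omitted here.
};
```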
@@ -979,7 +985,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1101,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1261,7 +1267,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1281,7 +1287,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1302,10 +1308,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-          ggml_context * ctx,
-           ggml_cgraph * gf,
-    const llama_ubatch & ubatch,
-        llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1317,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.memory      =*/ memory.get(),
+            /*.mstate      =*/ mstate,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -2020,8 +2027,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2033,13 +2040,13 @@ void llama_context::opt_epoch_iter(
         };
 
         uint32_t pos_batch = 0;
-        while (const auto * ubatch_ptr = decode_state->next()) {
-            const auto & ubatch = *ubatch_ptr;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2073,7 +2080,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        }
+        } while (kv_state->next());
     }
 }
 
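For orientation, the caller-side flow that the decode() and opt_epoch_iter() hunks converge on is condensed below. This is an excerpt-style sketch assuming the surrounding llama_context members used in the hunks above (kv_self, cparams, batch, embd_pooled, n_outputs_all, n_tokens_all); it is not standalone code and elides error handling and output bookkeeping.

```cpp
// Condensed sketch of the new per-ubatch flow (see decode() above for the full logic).
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled,
                                    /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return -2; // could not prepare the batch against the KV cache
}

do {
    const auto & ubatch = kv_state->get_ubatch();

    ggml_status status;
    // process_ubatch() first applies the memory state, then builds, allocates
    // and computes the graph, reporting the outcome through `status`
    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
    if (!res) {
        break; // the ubatch failed or was aborted
    }
} while (kv_state->next());
```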