@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
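
Note: kv_self_update() now returns a bool indicating whether the KV cache actually performed any work (shift/defrag) and, consequently, whether a new worst-case graph was reserved. The return value is consumed by the retry logic in llama_context::decode (next hunk); a minimal sketch of that usage pattern, taken from the patch itself:

    // schedule a defrag with top priority, then retry the batch only if
    // the update actually changed the cache
    kv_self->defrag_sched(-1.0f);
    if (kv_self_update()) {
        // cache was modified and a fresh worst-case graph is reserved - safe to retry
    }
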
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
     }
 
     // reserve output buffer
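
Note: the defrag-and-retry policy now lives here instead of in the llama_decode wrapper. On LLAMA_MEMORY_STATUS_FAILED_PREPARE the code schedules an immediate defrag (defrag_sched(-1.0f) requests it unconditionally) and loops back to init_batch, but only when kv_self_update() confirms the defrag actually changed the cache; the did_defrag flag bounds the loop to a single retry. A condensed, self-contained sketch of the policy (the function and its parameters are stand-ins, not llama.cpp API):

    // returns 0 on success, 1 when no slot is available even after one defrag
    int place_with_retry(bool (*try_place)(), bool (*force_defrag)()) {
        bool did_defrag = false;
        while (true) {
            if (try_place()) {
                return 0;                     // slot found - proceed with decode
            }
            if (did_defrag || !force_defrag()) {
                return 1;                     // recoverable: caller may shrink the batch
            }
            did_defrag = true;                // defrag freed space - retry exactly once
        }
    }
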
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
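
Note: with the retry handled inside llama_context::decode, this wrapper shrinks to a thin shim; ret == 1 (no KV cache slot) is deliberately excluded from the error log because the callee already emits a warning and the condition is recoverable. A hedged sketch of how a caller might react to it (hypothetical handling, not part of the patch):

    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // recoverable: e.g. split the batch into smaller chunks and retry
    } else if (ret != 0) {
        // fatal decode error - abort or reset the context
    }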