@@ -633,7 +633,15 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status * ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
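Aside for readers following this hunk: the new `mstate` argument is only exercised here through `apply()`, while the rest of the diff also calls `get_status()`, `get_ubatch()`, `next()` and `out_ids()` on the state returned by `init_batch()`. Below is a rough, hypothetical reconstruction of the interface those call sites imply; the actual `llama_memory_state_i` declaration lives elsewhere in the PR and may differ in names and types.

```cpp
// Hypothetical sketch inferred only from the call sites visible in this diff;
// the real llama_memory_state_i may declare more members or use other types.
#include <cstdint>
#include <vector>

struct llama_ubatch;                                        // defined elsewhere in llama.cpp
enum llama_memory_status { LLAMA_MEMORY_STATUS_SUCCESS /* , ... */ };

struct llama_memory_state_i {
    virtual ~llama_memory_state_i() = default;

    // make this state current (e.g. commit the reserved KV cells) before building the graph
    virtual bool apply() = 0;

    // advance to the next ubatch; returns false once the batch is exhausted
    virtual bool next() = 0;

    // the ubatch this state currently points at
    virtual const llama_ubatch & get_ubatch() const = 0;

    // outcome of splitting/preparing the batch in init_batch()
    virtual llama_memory_status get_status() const = 0;

    // maps produced outputs back to positions in the original batch (element type is a guess)
    virtual std::vector<int32_t> & out_ids() = 0;
};
```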
@@ -748,7 +756,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, &status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +935,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +963,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -979,7 +987,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), &status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1100,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
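Taken together, the decode hunks above change the loop from pulling a ubatch pointer out of `decode_state->next()` at the top of each iteration to a do/while shape: `init_batch()` leaves the state positioned on the first ubatch, the body reads it via `get_ubatch()` and hands the state to `process_ubatch()`, and `next()` only advances at the bottom. The self-contained toy below (stand-in names, not llama.cpp API) illustrates why the first ubatch is processed before `next()` is ever called:

```cpp
// Toy illustration of the pull-style iteration used above; all names are stand-ins.
#include <cstdio>
#include <vector>

struct toy_batch_state {
    std::vector<int> ubatches{10, 20, 30};             // pretend sizes of three ubatches
    size_t cur = 0;

    const int & get_ubatch() const { return ubatches[cur]; }
    bool next() { return ++cur < ubatches.size(); }    // false once exhausted
};

int main() {
    toy_batch_state state;   // analogous to what init_batch() would return
    do {
        // the body runs for the first ubatch before next() is ever called,
        // mirroring the do { ... } while (kv_state->next()) shape of the diff
        std::printf("processing ubatch of size %d\n", state.get_ubatch());
    } while (state.next());
}
```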
@@ -1101,7 +1109,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
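The `out_ids` change above only renames the state variable; the surrounding, unchanged code appears to check (via `sorted_output`) whether the collected output ids still follow the original batch order. A minimal, hypothetical sketch of that kind of check, using a plain vector rather than the real `out_ids()` type:

```cpp
// Minimal sketch of an "outputs still in batch order" check; not the actual llama.cpp code.
#include <cstdint>
#include <vector>

static bool outputs_in_batch_order(const std::vector<int32_t> & out_ids) {
    for (size_t i = 1; i < out_ids.size(); ++i) {
        if (out_ids[i] < out_ids[i - 1]) {
            return false;   // a later ubatch produced an earlier batch position
        }
    }
    return true;
}
```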
@@ -2020,8 +2028,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
        }
@@ -2033,8 +2041,8 @@ void llama_context::opt_epoch_iter(
        };
 
        uint32_t pos_batch = 0;
-       while (const auto * ubatch_ptr = decode_state->next()) {
-           const auto & ubatch = *ubatch_ptr;
+       do {
+           const auto & ubatch = kv_state->get_ubatch();
 
            n_outputs = ubatch.n_tokens;
 
@@ -2073,7 +2081,7 @@ void llama_context::opt_epoch_iter(
            ggml_free(ctx_compute_opt);
 
            pos_batch += ubatch.n_tokens;
-       }
+       } while (kv_state->next());
    }
 }
 