@@ -285,7 +285,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
             // max number of outputs
             n_outputs = ubatch_pp.n_tokens;
@@ -305,7 +305,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
             n_outputs = ubatch_tg.n_tokens;
 
@@ -324,7 +324,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
             n_outputs = ubatch_pp.n_tokens;
 
@@ -472,7 +472,7 @@ void llama_context::kv_self_update() {
         kv_self->set_full();
 
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
         auto * gf = graph_init();
         graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
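Note on the four hunks above: they all make the same edit, inserting two extra brace-initialized entries ({-1}, {-1}) ahead of &token in llama_ubatch's aggregate initializer. The diff does not show which members were added to the struct, so the sketch below uses hypothetical field names purely to illustrate why positional aggregate initialization forces every call site to list the new members in declaration order:

    #include <cstdint>

    // hypothetical stand-in for llama_ubatch; only the first four fields and the
    // token pointer mirror the diff, the two new members are guesses
    struct ubatch_sketch {
        bool            equal_seqs;
        uint32_t        n_tokens;
        uint32_t        n_seq_tokens;
        uint32_t        n_seqs;
        int32_t         new_field_a[1]; // assumed: the member initialized with {-1}
        int32_t         new_field_b[1]; // assumed: the member initialized with {-1}
        const int32_t * token;
    };

    int main() {
        int32_t token = 1;
        // initializers are positional: the two new entries must sit exactly where the
        // members are declared, otherwise &token would land in the wrong field
        ubatch_sketch u = { true, 512, 512 / 4, 4, {-1}, {-1}, &token };
        (void) u;
        return 0;
    }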
@@ -731,8 +731,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     n_outputs = n_tokens;
 
-    // batch_manager->prepare(ubatch);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -883,8 +881,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     if (batch.token) {
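This is one of several places where llama_kv_cache_guard disappears; the matching kv_guard.commit() calls are removed further down in decode() and in opt_epoch_iter(). From these call sites it reads as a commit-or-rollback RAII helper. A minimal sketch of that pattern, with assumed names and behavior rather than the real class:

    // commit/rollback guard sketch; the real llama_kv_cache_guard may differ
    struct kv_cache_iface {
        virtual void restore() = 0; // assumed: undo uncommitted slot allocations
        virtual ~kv_cache_iface() = default;
    };

    struct kv_guard_sketch {
        explicit kv_guard_sketch(kv_cache_iface * kv) : kv(kv) {}

        ~kv_guard_sketch() {
            if (!committed) {
                kv->restore(); // early return / error path rolls the cache back
            }
        }

        void commit() { committed = true; } // success path keeps the changes

        kv_cache_iface * kv;
        bool committed = false;
    };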
@@ -924,21 +920,24 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+    // handle any pending defrags/shifts
+    kv_self_update();
+
+    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!decode_state) {
+        return 1;
+    }
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
     int64_t n_outputs_prev = 0;
 
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    while (const auto * ubatch_ptr = decode_state->next()) {
+        const auto & ubatch = *ubatch_ptr;
 
         // count the outputs in this u_batch
         {
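The hunk above replaces the sbatch_init/ubatch_next pair with a single kv_self->init(...) call that returns a state object, and the decode loop now pumps that object via next() until it returns null (with kv_self_update() pulled in front so pending shifts/defrag are applied before the batch is placed). The mock below reproduces only the shape of that call pattern; the real types, splitting logic, and failure handling live in the PR's kv-cache code and are not shown here:

    #include <algorithm>
    #include <cstdio>
    #include <memory>
    #include <vector>

    struct ubatch_mock { int n_tokens; };

    // stand-in for the state object returned by kv_self->init(): owns the planned
    // micro-batches and hands them out one at a time
    struct decode_state_mock {
        std::vector<ubatch_mock> ubatches;
        size_t pos = 0;

        const ubatch_mock * next() {
            return pos < ubatches.size() ? &ubatches[pos++] : nullptr;
        }
    };

    // stand-in for kv_self->init(batch, n_ubatch, ...): split the batch up front;
    // returning nullptr would correspond to the `return 1` early-out in decode()
    static std::unique_ptr<decode_state_mock> init_mock(int n_tokens_all, int n_ubatch) {
        auto state = std::make_unique<decode_state_mock>();
        for (int i = 0; i < n_tokens_all; i += n_ubatch) {
            state->ubatches.push_back({ std::min(n_ubatch, n_tokens_all - i) });
        }
        return state;
    }

    int main() {
        auto decode_state = init_mock(/*n_tokens_all =*/ 10, /*n_ubatch =*/ 4);
        if (!decode_state) {
            return 1;
        }
        while (const auto * ubatch_ptr = decode_state->next()) {
            const auto & ubatch = *ubatch_ptr;
            std::printf("ubatch of %d tokens\n", ubatch.n_tokens); // graph build/compute would go here
        }
        return 0;
    }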
@@ -957,11 +956,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1072,17 +1066,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_prev += n_outputs;
     }
 
-    // finalize the batch processing
-    kv_guard.commit();
-
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
 
     // set output mappings
     {
         bool sorted_output = true;
 
-        auto & out_ids = sbatch.out_ids;
+        auto & out_ids = decode_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
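The other caller-visible change in decode() is that the output-index list now travels with the batch-splitting state: out_ids is no longer read off a local llama_sbatch but obtained from decode_state->out_ids(). Only the call-site shape is visible in this diff; the sketch below assumes a mutable reference to a container whose size matches the number of outputs, with the element type guessed:

    #include <cstdint>
    #include <vector>

    // shape of the accessor implied by `auto & out_ids = decode_state->out_ids();`
    // (ownership by the state object and the element type are assumptions)
    struct decode_state_out_ids_sketch {
        std::vector<int64_t> & out_ids() { return ids; } // caller reorders/consumes it in place

    private:
        std::vector<int64_t> ids; // maps output row -> position in the original batch (assumed)
    };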
@@ -1939,7 +1930,6 @@ void llama_context::opt_epoch_iter(
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -1962,25 +1952,26 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        // llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!decode_state) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
 
-        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
-            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+        uint32_t pos_batch = 0;
+        while (const auto * ubatch_ptr = decode_state->next()) {
+            const auto & ubatch = *ubatch_ptr;
 
-            n_outputs = ubatch.n_tokens;
+            pos_batch += ubatch.n_tokens;
 
-            // TODO: not sure if this is needed
-            if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
-            }
+            n_outputs = ubatch.n_tokens;
 
             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
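In opt_epoch_iter the ubatch loop changes the same way, and the position bookkeeping is inverted: instead of a for loop that strides pos_batch by a fixed n_ubatch, pos_batch becomes a running total advanced by each ubatch's actual n_tokens, while termination is driven by next() returning null. A small self-contained illustration of that control-flow swap (the sizes are made up):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // pretend decode_state->next() produced these micro-batches (n_batch = 10, n_ubatch = 4)
        const std::vector<uint32_t> ubatch_sizes = { 4, 4, 2 };

        uint32_t pos_batch = 0;
        size_t   i         = 0;

        // old: for (pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch)
        // new: loop until the state runs out of micro-batches, counting tokens as they come
        while (i < ubatch_sizes.size()) {          // stands in for decode_state->next()
            const uint32_t n_tokens = ubatch_sizes[i++];
            pos_batch += n_tokens;                 // running total, as in the new code
            std::printf("consumed %u tokens so far\n", (unsigned) pos_batch);
        }
        return 0;
    }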
@@ -2017,8 +2008,6 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
         }
     }
-
-    kv_guard.commit();
 }
 
 void llama_context::opt_epoch(