From b226c5b1a72cef6140bc62a3b94ef773ba1c9dc7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 11:48:09 +0200 Subject: [PATCH 01/12] refactor llama_batch_get_one --- include/llama.h | 20 +++----- src/llama.cpp | 130 +++++++++++++++++++++++++++--------------------- 2 files changed, 80 insertions(+), 70 deletions(-) diff --git a/include/llama.h b/include/llama.h index 7cae1bbe2e5b8..dabd64dbbffed 100644 --- a/include/llama.h +++ b/include/llama.h @@ -232,8 +232,11 @@ extern "C" { // - token : the token ids of the input (used when embd is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_decode) // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL, only the logits for last token will be returned) // typedef struct llama_batch { int32_t n_tokens; @@ -244,15 +247,6 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" - - // NOTE: helpers for smooth API transition - can be deprecated in the future - // for future-proof code, use the above fields instead and ignore everything below - // - // pos[i] = all_pos_0 + i*all_pos_1 - // - llama_pos all_pos_0; // used if pos == NULL - llama_pos all_pos_1; // used if pos == NULL - llama_seq_id all_seq_id; // used if seq_id == NULL } llama_batch; enum llama_model_kv_override_type { @@ -775,15 +769,15 @@ extern "C" { // Decoding // - // Return batch for single sequence of tokens starting at pos_0 + // Return batch for single sequence of tokens + // The sequence ID will be fixed to 0 + // The position of the tokens will be tracked automatically by llama_decode // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id); + int32_t n_tokens); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids diff --git a/src/llama.cpp b/src/llama.cpp index 3443b0689bf5e..825999b3c6ab6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2941,9 +2941,6 @@ struct llama_sbatch_seq { llama_seq_id * seq_id; size_t offset; size_t length; - - // helper for smoother batch API transition -- can be deprecated in the future - llama_seq_id all_seq_id; // used if seq_id == NULL }; // sequence-length-aware batch splitting @@ -3038,30 +3035,18 @@ struct llama_sbatch { } else { ubatch.embd = nullptr; } - // from here on, the else branches are deprecated; - // they are helpers for smoother batch API transition - if (batch->pos) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; - } - } else { - // simple split - ubatch.pos = batch->pos + seq.offset; - } - } else { + if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { - llama_pos bi = ids[seq.offset + i]; - ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; } if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; - } else { - GGML_ASSERT(seq.n_seq_id == 1); - ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { // simple split @@ -3074,10 +3059,6 @@ struct llama_sbatch { } if (batch->seq_id) { ubatch.seq_id = batch->seq_id + seq.offset; - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; - } } } if (logits_all) { @@ -3196,7 +3177,6 @@ struct llama_sbatch { s.seq_id = nullptr; s.offset = 0; s.length = n_tokens; - s.all_seq_id = batch.all_seq_id; return; } std::sort(ids.begin(), ids.end(), @@ -3219,7 +3199,7 @@ struct llama_sbatch { if (batch.pos) { return batch.pos[a] < batch.pos[b]; } - // no pos, sort by id (assuming batch.all_pos_1 is positive) + // no pos, sort by id return a < b; } // shared prompts go first @@ -3229,30 +3209,25 @@ struct llama_sbatch { // init seq llama_sbatch_seq * last_seq = nullptr; - if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { - for (size_t i = 0; i < n_tokens; ++i) { - const size_t bi = ids[i]; - const int32_t n_seqs = batch.n_seq_id[bi]; - llama_seq_id * seq_ids = batch.seq_id[bi]; - if (last_seq != nullptr) { - bool same = n_seqs == last_seq->n_seq_id; - for (int32_t j = 0; same && j < n_seqs; ++j) { - if (seq_ids[j] != last_seq->seq_id[j]) { - same = false; - } - } - if (same) { - last_seq->length += 1; - continue; + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; } } - llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; - seq.push_back(new_seq); - last_seq = &seq.back(); + if (same) { + last_seq->length += 1; + continue; + } } - } else { - llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; seq.push_back(new_seq); + last_seq = &seq.back(); } // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), @@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id) { + int32_t n_tokens) { return { /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, @@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one( /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, }; } @@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, - /*all_pos_0 =*/ 0, - /*all_pos_1 =*/ 0, - /*all_seq_id =*/ 0, }; if (embd) { @@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } +// temporary allocate memory for the input batch if needed +struct llama_batch_allocr { + static const llama_seq_id default_seq_id = 0; + std::array seq_id_0 = {default_seq_id}; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + // fulfill the batch returned by llama_batch_get_one + struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) { + struct llama_batch batch = in_batch; + if (!batch.pos) { + // determine the last position in KV cache + llama_pos last_pos; + for (const auto & cell : ctx->kv_self.cells) { + if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) { + last_pos = std::max(last_pos, cell.pos); + } + } + pos.resize(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + pos[i] = i+last_pos; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.reserve(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.reserve(batch.n_tokens); + for (int32_t i = 1; i <= batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.reserve(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } + } +}; + int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_encode_internal(*ctx, batch); + llama_batch_allocr batch_allocr; + const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21147,7 +21162,8 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_decode_internal(*ctx, batch); + llama_batch_allocr batch_allocr; + const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } From 1c486169ed829ba10c625a5b162424f68608fd0a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 12:11:00 +0200 Subject: [PATCH 02/12] adapt all examples --- common/common.cpp | 4 +- examples/batched-bench/batched-bench.cpp | 1 - .../cvector-generator/cvector-generator.cpp | 2 +- examples/eval-callback/eval-callback.cpp | 2 +- examples/imatrix/imatrix.cpp | 13 ++++++- examples/infill/infill.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 3 -- examples/llava/llava-cli.cpp | 2 +- examples/llava/llava.cpp | 38 ++++++++++++++++++- examples/llava/minicpmv-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 4 +- examples/lookup/lookup.cpp | 4 +- examples/main/main.cpp | 4 +- examples/parallel/parallel.cpp | 1 - examples/perplexity/perplexity.cpp | 27 ++++++++++--- examples/save-load-state/save-load-state.cpp | 8 ++-- examples/server/server.cpp | 1 - examples/speculative/speculative.cpp | 6 +-- src/llama.cpp | 1 + 20 files changed, 92 insertions(+), 37 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index a0611f3d1734b..2538361291abb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -912,7 +912,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } if (llama_model_has_encoder(model)) { - llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == -1) { decoder_start_token_id = bos; @@ -921,7 +921,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } llama_kv_cache_clear(lctx); llama_synchronize(lctx); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 4a15941f19abe..921e5dc3b3c43 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -74,7 +74,6 @@ int main(int argc, char ** argv) { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 41bf4eb2a406c..102e3c517f7be 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_cache_clear(ctx); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 6d629fe4ef189..0ab2f0d0e7b45 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const gpt_params & params) { std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index c8e273529e0fe..f743d22691f17 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -508,12 +508,21 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - // TODO: use batch.logits to save computations instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + llama_batch batch = llama_batch_init(batch_size, 0, 1); + for (int i = 0; i < batch_size; i++) { + batch. token[i] = tokens[batch_start + i]; + batch. pos[i] = j*n_batch + i; + batch.logits[i] = true; + batch.seq_id[i][0] = 0; + } + + if (llama_decode(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); return false; } + llama_batch_free(batch); + // restore the original token in case it was set to BOS tokens[batch_start] = token_org; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index d52425ae61ef3..67331eb704849 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -396,7 +396,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index fb1d387b2b11d..d989140a64cc6 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1446,7 +1446,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0)); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); n_processed += n_tokens; } @@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0)); + llama_decode(ctx, llama_batch_get_one(&token, 1)); llama_synchronize(ctx); token = std::rand() % n_vocab; } diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index f611809c6deff..dd9f5a8db775e 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, nullptr, nullptr, nullptr, - 0, - 0, - 0, }; if (embd) { diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 8f437863f6d77..509927360a8a8 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 8558c6bdcae0f..aa94ff22c2986 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co return true; } +struct llava_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { int n_embd = llama_n_embd(llama_get_model(ctx_llama)); @@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; - if (llama_decode(ctx_llama, batch)) { + float * embd = image_embed->embed+i*n_embd; + llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); + if (llama_decode(ctx_llama, llava_batch.batch)) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index c5156c35b029c..b31832bb0d966 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 49870b4a4e724..e10a27691ad62 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -89,8 +89,8 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); for (int s = 1; s < W + G + 1; ++s) { llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 2ccd0e6c18814..db980e9d02f60 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -89,8 +89,8 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); const auto t_enc_end = ggml_time_us(); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6bbb1e13ed7ac..5801b4fe286a4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -528,7 +528,7 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) { + if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -648,7 +648,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 81e2f7ed7c825..6d3ad3d233e32 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -308,7 +308,6 @@ int main(int argc, char ** argv) { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 87347135e0bb7..10c982b842349 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -409,13 +409,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); + llama_batch batch = llama_batch_init(batch_size, 0, 1); + for (int i = 0; i < batch_size; i++) { + batch. token[i] = tokens[batch_start + i]; + batch. pos[i] = j*n_batch + i; + batch.logits[i] = true; + batch.seq_id[i][0] = 0; + } + //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - // TODO: use llama_batch.logits instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + if (llama_decode(ctx, batch)) { //LOG_ERR("%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } + llama_batch_free(batch); + // save original token and restore it after eval const auto token_org = tokens[batch_start]; @@ -699,7 +708,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); @@ -1790,12 +1798,21 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - // TODO: use llama_batch.logits instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + llama_batch batch = llama_batch_init(batch_size, 0, 1); + for (int i = 0; i < batch_size; i++) { + batch. token[i] = tokens[batch_start + i]; + batch. pos[i] = j*n_batch + i; + batch.logits[i] = true; + batch.seq_id[i][0] = 0; + } + + if (llama_decode(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); return; } + llama_batch_free(batch); + // restore the original token in case it was set to BOS tokens[batch_start] = token_org; diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 0117d9357959f..72f94beb8ab8a 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -49,7 +49,7 @@ int main(int argc, char ** argv) { auto tokens = llama_tokenize(ctx, params.prompt, true); // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0)); + llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size())); n_past += tokens.size(); // save state (rng, logits, embedding and kv_cache) to file @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -133,7 +133,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) { + if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); @@ -221,7 +221,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx3); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f343cc252f89a..163b2ac89ac37 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2283,7 +2283,6 @@ struct server_context { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index adf6255e1449f..bc67aae3982d3 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -155,9 +155,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1)); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); const auto t_enc_end = ggml_time_us(); diff --git a/src/llama.cpp b/src/llama.cpp index 825999b3c6ab6..4c83b9af723b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21144,6 +21144,7 @@ struct llama_batch_allocr { logits[logits.size() - 1] = true; batch.logits = logits.data(); } + return batch; } }; From 92769503dc565884f08d3fb297880e1368888e68 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 12:16:03 +0200 Subject: [PATCH 03/12] fix simple.cpp --- examples/simple/simple.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index be91b2891db78..59760fe95db22 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); // main loop @@ -175,7 +175,7 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0); + batch = llama_batch_get_one(&new_token_id, 1); n_decode += 1; } From 59fd6b61199fa814445925ca9f069038ac7a4b8e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 12:21:20 +0200 Subject: [PATCH 04/12] fix llama_bench --- examples/llama-bench/llama-bench.cpp | 12 ++++++------ src/llama.cpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5b8bb881006a8..6493cf6ed75d3 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1428,7 +1428,7 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { +static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1451,7 +1451,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat llama_synchronize(ctx); } -static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { +static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count); } //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count); } - test_gen(ctx, 1, 0, t.n_threads); + test_gen(ctx, 1, t.n_threads); } for (int i = 0; i < params.reps; i++) { @@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); + test_gen(ctx, t.n_gen, t.n_threads); } uint64_t t_ns = get_time_ns() - t_start; diff --git a/src/llama.cpp b/src/llama.cpp index 24f66c8807b7d..c3646eb4d645b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21093,7 +21093,7 @@ struct llama_batch_allocr { struct llama_batch batch = in_batch; if (!batch.pos) { // determine the last position in KV cache - llama_pos last_pos; + llama_pos last_pos = 0; for (const auto & cell : ctx->kv_self.cells) { if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) { last_pos = std::max(last_pos, cell.pos); From 7740c969d0470b91d57168651e976deb7cda4d9d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 13:59:26 +0200 Subject: [PATCH 05/12] fix --- src/llama.cpp | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c3646eb4d645b..c25ae1e1e4041 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21081,58 +21081,59 @@ void llama_batch_free(struct llama_batch batch) { } // temporary allocate memory for the input batch if needed +static const llama_seq_id batch_default_seq_id = 0; struct llama_batch_allocr { - static const llama_seq_id default_seq_id = 0; - std::array seq_id_0 = {default_seq_id}; + std::array seq_id_0 = {batch_default_seq_id}; std::vector pos; std::vector n_seq_id; std::vector seq_id; std::vector logits; - // fulfill the batch returned by llama_batch_get_one - struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) { - struct llama_batch batch = in_batch; + struct llama_batch batch; + // optionally fulfill the batch returned by llama_batch_get_one + llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) { + batch = in_batch; if (!batch.pos) { // determine the last position in KV cache llama_pos last_pos = 0; for (const auto & cell : ctx->kv_self.cells) { - if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) { + if (cell.has_seq_id(batch_default_seq_id)) { last_pos = std::max(last_pos, cell.pos); } } + last_pos++; // next position pos.resize(batch.n_tokens); - for (int32_t i = 1; i <= batch.n_tokens; i++) { + for (int32_t i = 0; i < batch.n_tokens; i++) { pos[i] = i+last_pos; } batch.pos = pos.data(); } if (!batch.n_seq_id) { - n_seq_id.reserve(batch.n_tokens); - for (int32_t i = 1; i <= batch.n_tokens; i++) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { n_seq_id[i] = seq_id_0.size(); } batch.n_seq_id = n_seq_id.data(); } if (!batch.seq_id) { - seq_id.reserve(batch.n_tokens); - for (int32_t i = 1; i <= batch.n_tokens; i++) { + seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { seq_id[i] = seq_id_0.data(); } batch.seq_id = seq_id.data(); } if (!batch.logits) { - logits.reserve(batch.n_tokens); + logits.resize(batch.n_tokens); logits[logits.size() - 1] = true; batch.logits = logits.data(); } - return batch; } }; int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr; - const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); + llama_batch_allocr batch_allocr(ctx, batch); + const int ret = llama_encode_internal(*ctx, batch_allocr.batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21143,8 +21144,8 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr; - const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch)); + llama_batch_allocr batch_allocr(ctx, batch); + const int ret = llama_decode_internal(*ctx, batch_allocr.batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } From 6a9769a260d64adaf7478b9ca41681190491ada7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Oct 2024 14:36:48 +0200 Subject: [PATCH 06/12] fix context shifting --- examples/infill/infill.cpp | 2 +- examples/main/main.cpp | 2 +- src/llama.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 99b13dd2698de..d8a02fee47b41 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -376,7 +376,7 @@ int main(int argc, char ** argv) { n_past, n_left, n_ctx, params.n_keep, n_discard); llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard); n_past -= n_discard; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9aba3487479eb..bc7e839a0b12f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -582,7 +582,7 @@ int main(int argc, char ** argv) { n_past, n_left, n_ctx, params.n_keep, n_discard); llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past + 1 , -n_discard); n_past -= n_discard; diff --git a/src/llama.cpp b/src/llama.cpp index c25ae1e1e4041..5d41e9b19ee90 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21134,7 +21134,7 @@ int32_t llama_encode( struct llama_batch batch) { llama_batch_allocr batch_allocr(ctx, batch); const int ret = llama_encode_internal(*ctx, batch_allocr.batch); - if (ret < 0) { + if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21146,7 +21146,7 @@ int32_t llama_decode( struct llama_batch batch) { llama_batch_allocr batch_allocr(ctx, batch); const int ret = llama_decode_internal(*ctx, batch_allocr.batch); - if (ret < 0) { + if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } From 0639ff16d092f3f63c2c393a17b261190dbf1574 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 22:47:27 +0200 Subject: [PATCH 07/12] free batch before return --- examples/imatrix/imatrix.cpp | 1 + examples/perplexity/perplexity.cpp | 2 ++ 2 files changed, 3 insertions(+) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 8a0b425d83893..1e97d2980b303 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -518,6 +518,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { if (llama_decode(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return false; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 57d6fd1b29b6e..181a3c86def15 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -423,6 +423,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); if (llama_decode(ctx, batch)) { //LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return {tokens, -1, logit_history, prob_history}; } @@ -1821,6 +1822,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { if (llama_decode(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return; } From 734f9e29de8421afac198b86ad454937d94e672c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 22:51:30 +0200 Subject: [PATCH 08/12] use common_batch_add, reuse llama_batch in loop --- examples/imatrix/imatrix.cpp | 13 ++++++------- examples/perplexity/perplexity.cpp | 13 ++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 1e97d2980b303..70ff47768c02b 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -496,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -508,12 +510,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - llama_batch batch = llama_batch_init(batch_size, 0, 1); + common_batch_clear(batch); for (int i = 0; i < batch_size; i++) { - batch. token[i] = tokens[batch_start + i]; - batch. pos[i] = j*n_batch + i; - batch.logits[i] = true; - batch.seq_id[i][0] = 0; + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); } if (llama_decode(ctx, batch)) { @@ -522,8 +521,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { return false; } - llama_batch_free(batch); - // restore the original token in case it was set to BOS tokens[batch_start] = token_org; @@ -533,6 +530,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { } } + llama_batch_free(batch); + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 181a3c86def15..252ef56ba6246 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1800,6 +1800,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -1812,12 +1814,9 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - llama_batch batch = llama_batch_init(batch_size, 0, 1); + common_batch_clear(batch); for (int i = 0; i < batch_size; i++) { - batch. token[i] = tokens[batch_start + i]; - batch. pos[i] = j*n_batch + i; - batch.logits[i] = true; - batch.seq_id[i][0] = 0; + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); } if (llama_decode(ctx, batch)) { @@ -1826,8 +1825,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { return; } - llama_batch_free(batch); - // restore the original token in case it was set to BOS tokens[batch_start] = token_org; @@ -1837,6 +1834,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } } + llama_batch_free(batch); + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { From 7264596a5c0e86173f4c1f8aeae78fa33a216144 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 22:53:05 +0200 Subject: [PATCH 09/12] null terminated seq_id list --- src/llama.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index df19e1207fe2f..b24d4af80ac3b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21199,7 +21199,8 @@ struct llama_batch_allocr { batch.n_seq_id = n_seq_id.data(); } if (!batch.seq_id) { - seq_id.resize(batch.n_tokens); + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; for (int32_t i = 0; i < batch.n_tokens; i++) { seq_id[i] = seq_id_0.data(); } From 6395174a5410ba6634b8256e6cbc4c4625e46907 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 23:09:05 +0200 Subject: [PATCH 10/12] fix save-load-state example --- examples/save-load-state/save-load-state.cpp | 30 ++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index f72832d9cc672..5f60a86cbc2d2 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,9 +48,16 @@ int main(int argc, char ** argv) { // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); + // prepare the batch + llama_batch batch = llama_batch_init(tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, {0}, false); + } + batch.logits[batch.n_tokens - 1] = true; // generate next token + // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size())); - n_past += tokens.size(); + llama_decode(ctx, batch); + n_past += batch.n_tokens; // save state (rng, logits, embedding and kv_cache) to file { @@ -77,8 +84,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx); llama_free_model(model); return 1; @@ -133,8 +144,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx2); llama_free_model(model); return 1; @@ -221,8 +236,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {1}, true); + + if (llama_decode(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx3); llama_free_model(model); return 1; @@ -236,6 +255,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); + llama_batch_free(batch); llama_free(ctx3); llama_free_model(model); From 4be7ecf25e5a13c4a4f614f4b7d3bf095805c9ba Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 23:19:52 +0200 Subject: [PATCH 11/12] fix perplexity --- examples/perplexity/perplexity.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 252ef56ba6246..e803ff143f7d1 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -408,16 +408,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - llama_batch batch = llama_batch_init(batch_size, 0, 1); + common_batch_clear(batch); for (int i = 0; i < batch_size; i++) { - batch. token[i] = tokens[batch_start + i]; - batch. pos[i] = j*n_batch + i; - batch.logits[i] = true; - batch.seq_id[i][0] = 0; + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); @@ -427,8 +426,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params return {tokens, -1, logit_history, prob_history}; } - llama_batch_free(batch); - // save original token and restore it after eval const auto token_org = tokens[batch_start]; @@ -445,6 +442,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params } } + llama_batch_free(batch); + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { From 5d99ae447b851d6358ee1b5d4a8750a061e69127 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 18 Oct 2024 15:56:28 +0200 Subject: [PATCH 12/12] correct token pos in llama_batch_allocr --- examples/infill/infill.cpp | 2 +- src/llama.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 54fe9fc98101d..f18362c91c7bf 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -376,7 +376,7 @@ int main(int argc, char ** argv) { n_past, n_left, n_ctx, params.n_keep, n_discard); llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard); + llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/src/llama.cpp b/src/llama.cpp index 1fa7a1a1717c2..1813dd29be2b2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21141,7 +21141,7 @@ struct llama_batch_allocr { batch = in_batch; if (!batch.pos) { // determine the last position in KV cache - llama_pos last_pos = 0; + llama_pos last_pos = -1; for (const auto & cell : ctx->kv_self.cells) { if (cell.has_seq_id(batch_default_seq_id)) { last_pos = std::max(last_pos, cell.pos);