Commit 30f1db9

remove C API llama_batch_ext_init_from_text
1 parent bd51d63

File tree

11 files changed: +73 -90 lines
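At a glance, every call site moves from the removed C entry point to the header-only C++ wrapper. A minimal before/after sketch of the migration, assuming a ctx and a tokenized prompt as in the examples touched below:

    // before this commit (C API, caller must free):
    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
    llama_decode_ext(ctx, batch);
    llama_batch_ext_free(batch);

    // after this commit (C++ wrapper from include/llama-cpp.h, freed by unique_ptr):
    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
    llama_decode_ext(ctx, batch.get());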

common/common.cpp

Lines changed: 2 additions & 2 deletions
@@ -1016,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
+            auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true);
             llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
+            auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
             llama_decode_ext(lctx, batch.get());
         }
         llama_kv_self_clear(lctx);

examples/lookahead/lookahead.cpp

Lines changed: 2 additions & 2 deletions
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
+    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());

examples/lookup/lookup.cpp

Lines changed: 2 additions & 2 deletions
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
 
     const auto t_enc_start = ggml_time_us();
 
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
+    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());

examples/run/run.cpp

Lines changed: 1 addition & 1 deletion
@@ -1017,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);
 
         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
+        batch = llama_batch_ext_ptr::init_from_text(&new_token_id, 1, llama_data.n_past, 0, true);
     }
 
     printf(LOG_COL_DEFAULT);

examples/save-load-state/save-load-state.cpp

Lines changed: 16 additions & 16 deletions
@@ -48,11 +48,11 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // prepare the batch
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
 
     // evaluate prompt
-    llama_decode_ext(ctx, batch);
-    n_past += llama_batch_ext_get_n_tokens(batch);
+    llama_decode_ext(ctx, batch.get());
+    n_past += llama_batch_ext_get_n_tokens(batch.get());
 
     // save state (rng, logits, embedding and kv_cache) to file
     {
@@ -79,13 +79,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
         llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode_ext(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_ext_free(batch);
+            llama_batch_ext_free(batch.get());
             return 1;
         }
         n_past += 1;
@@ -132,13 +132,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
         llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode_ext(ctx2, batch)) {
+        if (llama_decode_ext(ctx2, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_ext_free(batch);
+            llama_batch_ext_free(batch.get());
             return 1;
         }
         n_past += 1;
@@ -214,13 +214,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
         llama_seq_id seq_id = 1; // seq 1 instead of 0
-        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode_ext(ctx3, batch)) {
+        if (llama_decode_ext(ctx3, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_ext_free(batch);
+            llama_batch_ext_free(batch.get());
             return 1;
         }
         n_past += 1;
@@ -232,7 +232,7 @@ int main(int argc, char ** argv) {
     llama_sampler_free(smpl2);
     llama_sampler_free(smpl3);
 
-    llama_batch_ext_free(batch);
+    llama_batch_ext_free(batch.get());
 
     if (result0 != result2) {
         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);

examples/simple-chat/simple-chat.cpp

Lines changed: 8 additions & 7 deletions
@@ -1,4 +1,5 @@
 #include "llama.h"
+#include "llama-cpp.h"
 #include <cstdio>
 #include <cstring>
 #include <iostream>
@@ -109,21 +110,21 @@ int main(int argc, char ** argv) {
 
         // prepare a batch for the prompt
         llama_pos n_past = 0;
-        llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
-        n_past += llama_batch_ext_get_n_tokens(batch);
+        auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
+        n_past += llama_batch_ext_get_n_tokens(batch.get());
 
         llama_token new_token_id;
         while (true) {
             // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
             int n_ctx_used = llama_kv_self_used_cells(ctx);
-            if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
+            if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");
                 exit(0);
             }
 
-            if (llama_decode_ext(ctx, batch)) {
+            if (llama_decode_ext(ctx, batch.get())) {
                 GGML_ABORT("failed to decode\n");
             }
 
@@ -147,13 +148,13 @@ int main(int argc, char ** argv) {
             response += piece;
 
             // prepare the next batch with the sampled token
-            llama_batch_ext_clear(batch);
+            llama_batch_ext_clear(batch.get());
            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
+            llama_batch_ext_add_text(batch.get(), new_token_id, n_past, &seq_id, 1, true);
             n_past++;
         }
 
-        llama_batch_ext_free(batch);
+        llama_batch_ext_free(batch.get());
 
         return response;
     };

examples/simple/simple.cpp

Lines changed: 7 additions & 7 deletions
@@ -1,4 +1,5 @@
 #include "llama.h"
+#include "llama-cpp.h"
 #include <cstdio>
 #include <cstring>
 #include <string>
@@ -143,22 +144,22 @@ int main(int argc, char ** argv) {
 
     // prepare a batch for the prompt
 
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
 
     // main loop
 
     const auto t_main_start = ggml_time_us();
     int n_decode = 0;
     llama_token new_token_id;
 
-    for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
+    for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch.get()) < n_prompt + n_predict; ) {
         // evaluate the current batch with the transformer model
-        if (llama_decode_ext(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
 
-        n_pos += llama_batch_ext_get_n_tokens(batch);
+        n_pos += llama_batch_ext_get_n_tokens(batch.get());
 
         // sample the next token
         {
@@ -180,9 +181,9 @@ int main(int argc, char ** argv) {
         fflush(stdout);
 
         // prepare the next batch with the sampled token
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
         llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch.get(), new_token_id, n_pos, &seq_id, 1, true);
 
         n_decode += 1;
     }
@@ -200,7 +201,6 @@ int main(int argc, char ** argv) {
     llama_perf_context_print(ctx);
     fprintf(stderr, "\n");
 
-    llama_batch_ext_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_model_free(model);

examples/speculative/speculative.cpp

Lines changed: 3 additions & 3 deletions
@@ -165,9 +165,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0, true));
+    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
+    auto batch2 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input    , 0, 0, true);
     llama_decode_ext(ctx_tgt, batch0.get());
     llama_decode_ext(ctx_tgt, batch1.get());
    llama_decode_ext(ctx_dft, batch2.get());

include/llama-cpp.h

Lines changed: 21 additions & 11 deletions
@@ -37,21 +37,31 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
     llama_batch_ext_ptr() : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>() {}
     llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
 
-    // convenience function to create a batch from text tokens, without worrying about manually freeing it
+    // Convenience C++ wrapper to create a batch from text tokens, without worrying about manually freeing it
+    // First token will be at position pos0
+    // The sequence ID will be fixed to seq_id
+    // If output_last is true, the last token will have output set
     static llama_batch_ext_ptr init_from_text(llama_token * tokens,
-                                              int32_t n_tokens,
-                                              int32_t pos0,
-                                              int32_t seq_id,
-                                              bool output_last) {
-        return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last));
+                                              int32_t      n_tokens,
+                                              llama_pos    pos0,
+                                              llama_seq_id seq_id,
+                                              bool         output_last) {
+        llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
+        }
+        if (output_last) {
+            llama_batch_ext_set_output_last(batch);
+        }
+        return llama_batch_ext_ptr(batch);
     }
 
-    // convenience function to create a batch from text embeddings, without worrying about manually freeing it
+    // Convenience C++ wrapper to create a batch from text embeddings, without worrying about manually freeing it
     static llama_batch_ext_ptr init_from_embd(float * embd,
-                                              size_t n_tokens,
-                                              size_t n_embd,
-                                              int32_t pos0,
-                                              int32_t seq_id) {
+                                              size_t       n_tokens,
+                                              size_t       n_embd,
+                                              llama_pos    pos0,
+                                              llama_seq_id seq_id) {
         return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
     }
 };
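The token loop formerly in src/llama-batch.cpp now lives in this header: tokens get consecutive positions starting at pos0, all on the single sequence seq_id, and only the last token is flagged for output when output_last is true. A short usage sketch under those semantics (the model/context setup is assumed):

    #include "llama-cpp.h"
    #include <vector>

    std::vector<llama_token> toks = { /* tokenized prompt */ };
    // positions pos0 .. pos0 + n_tokens - 1 on sequence 0; only the last token has output set
    auto batch = llama_batch_ext_ptr::init_from_text(toks.data(), (int32_t) toks.size(),
                                                     /*pos0=*/0, /*seq_id=*/0, /*output_last=*/true);
    if (llama_decode_ext(ctx, batch.get()) != 0) {
        // handle failure; no llama_batch_ext_free() needed, the unique_ptr owns the batch
    }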

include/llama.h

Lines changed: 6 additions & 18 deletions
@@ -900,7 +900,7 @@ extern "C" {
     //
     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
             llama_token * tokens,
-            int32_t n_tokens), "use llama_batch_ext_init_from_text instead");
+            int32_t n_tokens), "use llama_batch_ext API instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
@@ -925,30 +925,18 @@ extern "C" {
             int32_t n_tokens,
             int32_t n_seq_max);
 
-    // Same with llama_batch_init, but initializes the batch with the provided text tokens
-    // First token will be at position pos0
-    // The sequence ID will be fixed to seq_id
-    // If output_last is true, the last token will have output set
-    // The batch has to be freed with llama_batch_ext_free()
-    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
-            llama_token * tokens,
-            int32_t n_tokens,
-            int32_t pos0,
-            int32_t seq_id,
-            bool output_last);
-
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // Size of embd should be n_tokens * n_embd
     // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
     // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
-            float * embd,
-            size_t n_tokens,
-            size_t n_embd,
-            int32_t pos0,
-            int32_t seq_id);
+            const float * embd,
+            size_t        n_tokens,
+            size_t        n_embd,
+            llama_pos     pos0,
+            llama_seq_id  seq_id);
 
     // Set arbitrary token to the embeddings batch
     // Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
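For comparison, the embeddings initializer stays in the C API, now with const-correct, strongly typed parameters. A hedged sketch of a call against the new signature (model, ctx, and the embedding values are assumed; n_embd comes from llama_model_n_embd() as the comment above notes):

    const size_t n_embd   = (size_t) llama_model_n_embd(model); // embeddings per token
    const size_t n_tokens = 4;                                  // example batch size
    std::vector<float> embd(n_tokens * n_embd);                 // must hold n_tokens * n_embd floats
    llama_batch_ext * batch = llama_batch_ext_init_from_embd(embd.data(), n_tokens, n_embd,
                                                             /*pos0=*/0, /*seq_id=*/0);
    llama_decode_ext(ctx, batch);
    llama_batch_ext_free(batch); // raw C API batch is still freed manually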

src/llama-batch.cpp

Lines changed: 5 additions & 21 deletions
@@ -337,22 +337,6 @@ struct llama_batch llama_batch_get_one(
     };
 }
 
-struct llama_batch_ext * llama_batch_ext_init_from_text(
-        llama_token * tokens,
-        int32_t n_tokens,
-        int32_t pos0,
-        int32_t seq_id,
-        bool output_last) {
-    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
-    for (int32_t i = 0; i < n_tokens; i++) {
-        llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
-    }
-    if (output_last) {
-        llama_batch_ext_set_output_last(batch);
-    }
-    return batch;
-}
-
 static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
     llama_batch_ext * batch = new llama_batch_ext{
         /*n_tokens =*/ 0,
@@ -390,11 +374,11 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
 }
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
-        float * embd,
-        size_t n_tokens,
-        size_t n_embd,
-        int32_t pos0,
-        int32_t seq_id) {
+        const float * embd,
+        size_t        n_tokens,
+        size_t        n_embd,
+        llama_pos     pos0,
+        llama_seq_id  seq_id) {
     struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
     for (size_t i = 0; i < n_tokens; i++) {
