
Commit 2cec1cf

move various places to batch.add_text
1 parent 2134cab commit 2cec1cf

File tree

15 files changed: +91, -128 lines

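The new helper used at the call sites below is a member of the llama_batch_ext_ptr wrapper pulled in via the "llama-cpp.h" include that several files gain in this commit. Its exact definition is not part of this diff; the following is only a sketch of the plumbing it hides, reconstructed from the removed llama_batch_ext_add_text() calls (the _sketch name is hypothetical):

#include "llama.h"
#include "llama-cpp.h" // provides llama_batch_ext_ptr used by the call sites below

// Hypothetical single-sequence wrapper with the same shape as batch.add_text(...):
// the old call sites declared a local llama_seq_id and passed its address with a
// count of 1; the helper takes the id by value and forwards it as a one-element array.
static void add_text_sketch(llama_batch_ext_ptr & batch,
                            llama_token token, llama_pos pos, llama_seq_id seq_id, bool output) {
    llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, output);
}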

common/speculative.cpp

Lines changed: 3 additions & 5 deletions
@@ -149,8 +149,6 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
-    const llama_seq_id seq_id = 0;
-
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -210,7 +208,7 @@ llama_tokens common_speculative_gen_draft(
 
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
+        batch.add_text(prompt_tgt[i], i - i_start, 0, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
@@ -227,7 +225,7 @@ llama_tokens common_speculative_gen_draft(
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
     llama_batch_ext_clear(batch.get());
-    llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true);
+    batch.add_text(id_last, n_past, 0, true);
 
     prompt.push_back(id_last);
 
@@ -266,7 +264,7 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
+        batch.add_text( id, n_past + i + 1, 0, true);
 
         // evaluate the drafted tokens on the draft model
         llama_decode_ext(ctx, batch.get());

examples/gritlm/gritlm.cpp

Lines changed: 10 additions & 15 deletions
@@ -1,6 +1,7 @@
 #include "arg.h"
 #include "common.h"
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <string>
 #include <vector>
@@ -13,10 +14,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
 
     for (uint64_t i = 0; i < sentences.size(); i++) {
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         const std::string input_string = instruction + sentences[i];
 
@@ -41,8 +42,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // add input to batch (this increments n_tokens)
         for (int32_t j = 0; j < n_toks; j++) {
-            const llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
+            batch.add_text(inputs[j], j, 0, j >= n_inst);
         }
 
         // clear previous kv_cache values (irrelevant for embeddings)
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         llama_set_causal_attn(ctx, false);
 
         // run model
-        llama_decode_ext(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
 
         // get embedding dimensions
         uint64_t n_embd = llama_model_n_embd(model);
@@ -90,8 +90,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 #endif
     }
 
-    llama_batch_ext_free(batch);
-
     return result;
 }
 
@@ -107,26 +105,25 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
-    llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
 
     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
 
     while (true) {
-        llama_batch_ext_clear(bat);
+        llama_batch_ext_clear(batch.get());
         {
             const int32_t n_inputs = inputs.size();
 
             for (int32_t i = 0; i < n_inputs; i++) {
-                const llama_seq_id seq_id = 0;
-                llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
+                batch.add_text(inputs[i], i_current_token++, 0, i == n_inputs - 1);
             }
         }
         inputs.clear();
 
-        llama_decode_ext(ctx, bat);
+        llama_decode_ext(ctx, batch.get());
 
-        llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
+        llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
 
         if (token == eos_token) {
             break;
@@ -147,8 +144,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
         std::printf("\n");
     }
 
-    llama_batch_ext_free(bat);
-
     return result;
 }
 
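Note that the explicit llama_batch_ext_free() calls disappear here (and in the files below) once the raw batch is held in a llama_batch_ext_ptr. The wrapper's definition is not shown in this diff; assuming it follows the unique_ptr-with-deleter pattern of the other smart pointers in llama-cpp.h, the ownership side could look roughly like this (the _sketch names are hypothetical, and the real type presumably also carries the add_text() member used above):

#include <memory>

#include "llama.h"

// assumed shape only: free the batch automatically when the wrapper goes out of scope
struct llama_batch_ext_deleter_sketch {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

using llama_batch_ext_ptr_sketch = std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter_sketch>;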

examples/llava/gemma3-cli.cpp

Lines changed: 2 additions & 4 deletions
@@ -92,8 +92,7 @@ static int eval_text(gemma3_context & ctx, std::string input, bool logits_last =
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
     llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false);
+        ctx.batch.add_text(t, ctx.n_past++, 0, false);
     }
     if (logits_last) {
         llama_batch_ext_set_output_last(ctx.batch.get());
@@ -180,8 +179,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
 
         // eval the token
         llama_batch_ext_clear(ctx.batch.get());
-        llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+        ctx.batch.add_text(token_id, ctx.n_past++, 0, true);
         if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
             LOG_ERR("failed to decode token\n");
             return 1;

examples/llava/qwen2vl-cli.cpp

Lines changed: 1 addition & 2 deletions
@@ -101,8 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
         for (int j = 0; j < n_eval; j++) {
             llama_token token = tokens[i + j];
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
+            batch.add_text(token, pos[j], 0, false);
         }
         llama_batch_ext_set_output_last(batch.get());
 

examples/lookup/lookup.cpp

Lines changed: 6 additions & 8 deletions
@@ -5,6 +5,7 @@
 #include "sampling.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <cstdint>
 #include <cstdio>
@@ -110,7 +111,7 @@ int main(int argc, char ** argv){
 
     std::vector<llama_token> draft;
 
-    llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
+    llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(params.n_ctx, 1));
 
     // debug
     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -196,9 +197,8 @@ int main(int argc, char ** argv){
         // clean the cache of draft tokens that weren't accepted
         llama_kv_self_seq_rm(ctx, 0, n_past, -1);
 
-        const llama_seq_id seq_id = 0;
-        llama_batch_ext_clear(batch_tgt);
-        llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
+        llama_batch_ext_clear(batch_tgt.get());
+        batch_tgt.add_text(draft[0], n_past, 0, true);
 
         // Draft already contains a single token sampled from the model:
         GGML_ASSERT(draft.size() == 1);
@@ -208,13 +208,13 @@ int main(int argc, char ** argv){
            common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
 
            for (size_t i = 1; i < draft.size(); ++i) {
-                batch_tgt
-                llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
+                batch_tgt.add_text(draft[i], n_past + i, 0, true);
            }
 
            t_draft_us += ggml_time_us() - t_start_draft_us;
            n_drafted += draft.size() - 1;
 
-        llama_decode_ext(ctx, batch_tgt);
+        llama_decode_ext(ctx, batch_tgt.get());
         ++n_past;
 
         draft.erase(draft.begin());
@@ -246,8 +246,6 @@ int main(int argc, char ** argv){
 
     common_sampler_free(smpl);
 
-    llama_batch_ext_free(batch_tgt);
-
     llama_backend_free();
 
     LOG("\n\n");

examples/parallel/parallel.cpp

Lines changed: 16 additions & 18 deletions
@@ -6,6 +6,7 @@
 #include "sampling.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <cmath>
 #include <cstdio>
@@ -174,7 +175,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1));
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -192,11 +193,10 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
         for (int32_t i = 0; i < n_tokens_system; ++i) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
+            batch.add_text(tokens_system[i], i, 0, false);
         }
 
-        if (llama_decode_ext(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -217,23 +217,23 @@ int main(int argc, char ** argv) {
             common_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         // decode any currently ongoing sequences
         for (auto & client : clients) {
             if (client.seq_id == -1) {
                 continue;
             }
 
-            client.i_batch = llama_batch_ext_get_n_tokens(batch);
+            client.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
             llama_seq_id seq_id = client.id + 1;
-            llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
+            batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true);
 
             client.n_decoded += 1;
         }
 
-        if (llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -245,7 +245,7 @@ int main(int argc, char ** argv) {
         }
 
         // insert new sequences for decoding
-        if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
                     client.seq_id = g_seq_id;
@@ -265,17 +265,17 @@ int main(int argc, char ** argv) {
 
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                         llama_seq_id seq_id = client.id + 1;
-                        llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
+                        batch.add_text(tokens_prompt[i], i + n_tokens_system, seq_id, false);
                     }
 
                     // extract the logits only for the last token
-                    if (llama_batch_ext_get_n_tokens(batch) > 0) {
-                        llama_batch_ext_set_output_last(batch);
+                    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
+                        llama_batch_ext_set_output_last(batch.get());
                     }
 
                     client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
-                    client.i_batch   = llama_batch_ext_get_n_tokens(batch) - 1;
+                    client.i_batch   = llama_batch_ext_get_n_tokens(batch.get()) - 1;
 
                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
 
@@ -289,14 +289,14 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             break;
         }
 
        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;
 
-        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
+        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get());
        for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
            // experiment: process in powers of 2
            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
@@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
 
            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
 
-            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
+            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch.get(), i, n_tokens);
            const int ret = llama_decode_ext(ctx, batch_view);
            llama_batch_ext_free(batch_view);
            if (ret != 0) {
@@ -413,8 +413,6 @@ int main(int argc, char ** argv) {
    // TODO: print sampling/grammar timings for all clients
    llama_perf_context_print(ctx);
 
-    llama_batch_ext_free(batch);
-
    llama_backend_free();
 
    LOG("\n\n");

examples/passkey/passkey.cpp

Lines changed: 2 additions & 4 deletions
@@ -144,8 +144,7 @@ int main(int argc, char ** argv) {
         llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
+            batch.add_text(tokens_list[i + j], n_past++, 0, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
@@ -179,8 +178,7 @@ int main(int argc, char ** argv) {
         llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
+            batch.add_text(tokens_list[i + j], n_past++, 0, false);
         }
 
         if (i + n_batch >= n_tokens_all) {

examples/save-load-state/save-load-state.cpp

Lines changed: 3 additions & 6 deletions
@@ -80,8 +80,7 @@ int main(int argc, char ** argv) {
         result0 += next_token_str;
 
         llama_batch_ext_clear(batch.get());
-        llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
+        batch.add_text(next_token, 0, 0, true);
 
         if (llama_decode_ext(ctx, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
@@ -133,8 +132,7 @@ int main(int argc, char ** argv) {
         result1 += next_token_str;
 
         llama_batch_ext_clear(batch.get());
-        llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
+        batch.add_text(next_token, 0, 0, true);
 
         if (llama_decode_ext(ctx2, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
@@ -215,8 +213,7 @@ int main(int argc, char ** argv) {
         result2 += next_token_str;
 
         llama_batch_ext_clear(batch.get());
-        llama_seq_id seq_id = 1; // seq 1 instead of 0
-        llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true);
+        batch.add_text(next_token, 0, 1, true);
 
         if (llama_decode_ext(ctx3, batch.get())) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
