@@ -276,15 +276,16 @@ void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool
 
 llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
     batch = new llama_batch_ext{
-        /*n_tokens       =*/ in_batch.n_tokens,
-        /*max_tokens     =*/ in_batch.n_tokens,
-        /*is_view        =*/ false,
-        /*tokens         =*/ in_batch.token,
-        /*embd           =*/ in_batch.embd,
-        /*pos            =*/ in_batch.pos,
-        /*n_seq_id       =*/ in_batch.n_seq_id,
-        /*seq_id         =*/ in_batch.seq_id,
-        /*logits         =*/ in_batch.logits,
+        /*n_tokens        =*/ in_batch.n_tokens,
+        /*max_tokens      =*/ in_batch.n_tokens,
+        /*n_pos_per_token =*/ 1,
+        /*is_view         =*/ false,
+        /*tokens          =*/ in_batch.token,
+        /*embd            =*/ in_batch.embd,
+        /*pos              =*/ in_batch.pos,
+        /*n_seq_id        =*/ in_batch.n_seq_id,
+        /*seq_id          =*/ in_batch.seq_id,
+        /*logits          =*/ in_batch.logits,
     };
     GGML_ASSERT(batch->n_tokens > 0);
     if (!in_batch.pos) {
@@ -338,17 +339,18 @@ struct llama_batch llama_batch_get_one(
     };
 }
 
-static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
+static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max, int32_t n_pos_per_token) {
     llama_batch_ext * batch = new llama_batch_ext{
-        /*n_tokens       =*/ 0,
-        /*max_tokens     =*/ n_tokens_alloc,
-        /*is_view        =*/ false,
-        /*tokens         =*/ nullptr,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
+        /*n_tokens        =*/ 0,
+        /*max_tokens      =*/ n_tokens_alloc,
+        /*n_pos_per_token =*/ n_pos_per_token,
+        /*is_view         =*/ false,
+        /*tokens          =*/ nullptr,
+        /*embd            =*/ nullptr,
+        /*pos             =*/ nullptr,
+        /*n_seq_id        =*/ nullptr,
+        /*seq_id          =*/ nullptr,
+        /*logits          =*/ nullptr,
     };
 
     if (n_embd) {
@@ -371,7 +373,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
 }
 
 struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
-    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx), n_pos_per_token);
 }
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
@@ -381,10 +384,10 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
         size_t n_embd,
         const llama_pos * pos,
         llama_seq_id seq_id) {
-    auto model = llama_get_model(ctx);
-    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1, n_pos_per_token);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
-    memcpy(batch->pos,  pos,  n_tokens * llama_n_pos_per_token(model) * sizeof(llama_pos));
+    memcpy(batch->pos,  pos,  n_tokens * n_pos_per_token * sizeof(llama_pos));
     for (size_t i = 0; i < n_tokens; i++) {
         batch->n_seq_id[i] = 1;
         batch->seq_id[i][0] = seq_id;
@@ -411,12 +414,16 @@ int32_t llama_batch_ext_add_text(
     }
     const int32_t output_id = batch->n_tokens;
     batch->token   [output_id] = token;
-    batch->pos     [output_id] = pos;
+    batch->n_seq_id[output_id] = n_seq_ids;
+    batch->logits  [output_id] = output;
+    for (int32_t i = 0; i < batch->n_pos_per_token; i++) {
+        // TODO: this is only used by qwen2vl for now, and text tokens only have 3 pos, the last is set to 0; we should improve this code in the future
+        batch->pos[output_id * batch->n_pos_per_token + i] = i < 3 ? pos : 0;
+    }
     batch->n_seq_id[output_id] = n_seq_ids;
     for (size_t j = 0; j < n_seq_ids; j++) {
         batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-    batch->logits  [output_id] = output;
     batch->n_tokens++;
     return output_id;
 }
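
Note: with the change above, `pos` is no longer one entry per token but a flat array of n_tokens * n_pos_per_token entries, with slot i of token t at pos[t * n_pos_per_token + i]. This is also why llama_batch_ext_init_from_embd copies n_tokens * n_pos_per_token positions and why the view in the next hunk offsets pos by offset * n_pos_per_token. A minimal caller-side sketch of that layout follows; the make_pos_buffer helper and the llama_pos alias are illustrative assumptions, not part of the patch:

    #include <cstdint>
    #include <vector>

    using llama_pos = int32_t; // assumption: matches the llama_pos typedef in llama.h

    // Build a flat position buffer for n_tokens tokens with n_pos_per_token slots each,
    // mirroring the indexing used by llama_batch_ext_add_text above: text tokens carry
    // their position in the first 3 slots, any extra slots are zeroed.
    static std::vector<llama_pos> make_pos_buffer(int32_t n_tokens, int32_t n_pos_per_token, llama_pos p0) {
        std::vector<llama_pos> pos(n_tokens * n_pos_per_token);
        for (int32_t t = 0; t < n_tokens; t++) {
            for (int32_t i = 0; i < n_pos_per_token; i++) {
                pos[t * n_pos_per_token + i] = i < 3 ? p0 + t : 0;
            }
        }
        return pos;
    }
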
@@ -461,15 +468,16 @@ struct llama_batch_ext * llama_batch_ext_get_view(
        return nullptr; // not yet supported
    }
    llama_batch_ext * batch_view = new llama_batch_ext{
-        /*n_tokens       =*/ n_tokens,
-        /*max_tokens     =*/ n_tokens,
-        /*is_view        =*/ true,
-        /*tokens         =*/ batch->token    + offset,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ batch->pos      + offset,
-        /*n_seq_id       =*/ batch->n_seq_id + offset,
-        /*seq_id         =*/ batch->seq_id   + offset,
-        /*logits         =*/ batch->logits   + offset,
+        /*n_tokens        =*/ n_tokens,
+        /*max_tokens      =*/ n_tokens,
+        /*n_pos_per_token =*/ batch->n_pos_per_token,
+        /*is_view         =*/ true,
+        /*tokens          =*/ batch->token    + offset,
+        /*embd            =*/ nullptr,
+        /*pos             =*/ batch->pos      + offset * batch->n_pos_per_token,
+        /*n_seq_id        =*/ batch->n_seq_id + offset,
+        /*seq_id          =*/ batch->seq_id   + offset,
+        /*logits          =*/ batch->logits   + offset,
    };
    return batch_view;
}
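
For context, a hypothetical caller-side sketch of the extended API after this patch. The llama_batch_ext_add_text and llama_batch_ext_get_view signatures are assumed from the definitions touched in this diff and may not match the final header:

    #include "llama.h"

    // Decode-side sketch, not canonical usage: fill a batch with a text prompt
    // and take a view over it.
    static void fill_prompt(struct llama_context * ctx, const llama_token * toks, int32_t n) {
        // n_pos_per_token is now taken from the model inside llama_batch_ext_init
        struct llama_batch_ext * batch = llama_batch_ext_init(ctx);
        llama_seq_id seq = 0;
        for (int32_t i = 0; i < n; i++) {
            // request output (logits) only for the last token; argument order is assumed
            llama_batch_ext_add_text(batch, toks[i], /*pos=*/ i, &seq, /*n_seq_ids=*/ 1, /*output=*/ i == n - 1);
        }
        // a view shares the parent's buffers; its pos pointer is offset by offset * n_pos_per_token
        struct llama_batch_ext * view = llama_batch_ext_get_view(batch, /*offset=*/ 0, /*n_tokens=*/ n);
        (void) view; // pass to the decode call used with llama_batch_ext (not shown here)
    }
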