
Commit c4fea7f: fix qwen2vl mrope position input

1 parent d18a79e

File tree: 5 files changed, +60 -43 lines

examples/llava/qwen2vl-cli.cpp

Lines changed: 10 additions & 8 deletions

@@ -66,8 +66,17 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

+        // transpose from layout 0123012301230123 to 0000111122223333
+        // TODO @ngxson : this is a low-effort solution, generated with the help of LLM; we should improve this in the future
+        std::vector<llama_pos> batch_mrope_pos_T(n_eval * 4);
+        for (int r = 0; r < 4; r++) {
+            for (int c = 0; c < n_eval; c++) {
+                batch_mrope_pos_T[c*4 + r] = batch_mrope_pos[r*n_eval + c];
+            }
+        }
+
         float * batch_embd = image_embed->embed+i*n_embd;
-        const llama_pos * pos = batch_mrope_pos.data();
+        const llama_pos * pos = batch_mrope_pos_T.data();
         auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0);

         if (llama_decode_ext(ctx_llama, batch.get())) {

@@ -90,13 +99,6 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
             n_eval = n_batch;
         }

-        // TODO: add mrope pos ids somewhere else
-        pos.resize(n_eval * 4);
-        std::fill(pos.begin(), pos.end(), 0);
-        for (int j = 0; j < n_eval * 3; j ++) {
-            pos[j] = *st_pos_id + (j % n_eval);
-        }
-
         llama_batch_ext_ptr batch(ctx_llama);
         for (int j = 0; j < n_eval; j++) {
            llama_token token = tokens[i + j];
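For readers following the layout change: the CLI builds batch_mrope_pos dimension-major (one block of n_eval entries per M-RoPE component, as the memcpy calls show), while the batch API now expects the positions token-major, with the four components of each token stored next to each other. Below is a minimal standalone sketch of that rearrangement, using made-up values; it is not part of the commit, but the inner assignment is the same index swap as in the hunk above.

```cpp
#include <cstdio>
#include <vector>

// Illustration only: 3 tokens, 4 M-RoPE components per token.
// dim_major is the layout the CLI builds, token_major is what the batch API expects.
int main() {
    const int n_eval = 3;
    std::vector<int> dim_major = {
        10, 11, 12,   // component 0 of tokens 0..2
        20, 21, 22,   // component 1
        30, 31, 32,   // component 2
        40, 41, 42,   // component 3
    };
    std::vector<int> token_major(n_eval * 4);
    for (int r = 0; r < 4; r++) {           // r = M-RoPE component
        for (int c = 0; c < n_eval; c++) {  // c = token index
            token_major[c*4 + r] = dim_major[r*n_eval + c];
        }
    }
    for (int c = 0; c < n_eval; c++) {
        printf("token %d: %d %d %d %d\n", c,
               token_major[c*4 + 0], token_major[c*4 + 1],
               token_major[c*4 + 2], token_major[c*4 + 3]);
    }
    return 0;
}
```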

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -928,6 +928,7 @@ extern "C" {
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // Size of embd should be n_tokens * n_embd
     // Size of pos should be n_tokens * n_pos_per_token
+    // If one token has multiple pos, the pos must follow the order: 000011112222...
     // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
     // The sequence ID will be fixed to seq_id
     // The batch has to be freed with llama_batch_ext_free()
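A hedged usage sketch of the documented ordering: with n_pos_per_token == 4 (the qwen2vl M-RoPE case on this branch), the pos buffer handed to llama_batch_ext_init_from_embd is filled token by token. get_mrope_pos() is a hypothetical helper standing in for whatever produces the per-token components; it is not part of this commit.

```cpp
#include <array>
#include <vector>
#include "llama.h" // llama_batch_ext API as declared on this branch

// Hypothetical helper, not part of the commit: the 4 M-RoPE components of token t.
std::array<llama_pos, 4> get_mrope_pos(int t);

// Sketch, assuming n_pos_per_token == 4: fill pos token-major (000011112222...)
// before passing it to llama_batch_ext_init_from_embd.
llama_batch_ext * make_embd_batch(llama_context * ctx, float * embd, int n_tokens, int n_embd) {
    std::vector<llama_pos> pos(n_tokens * 4);
    for (int t = 0; t < n_tokens; t++) {
        const std::array<llama_pos, 4> p = get_mrope_pos(t);
        for (int d = 0; d < 4; d++) {
            pos[t*4 + d] = p[d]; // the 4 components of token t are adjacent
        }
    }
    return llama_batch_ext_init_from_embd(ctx, embd, n_tokens, n_embd, pos.data(), /*seq_id=*/0);
}
```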

src/llama-batch.cpp

Lines changed: 42 additions & 34 deletions

@@ -276,15 +276,16 @@ void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool

 llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
     batch = new llama_batch_ext{
-        /*n_tokens       =*/ in_batch.n_tokens,
-        /*max_tokens     =*/ in_batch.n_tokens,
-        /*is_view        =*/ false,
-        /*tokens         =*/ in_batch.token,
-        /*embd           =*/ in_batch.embd,
-        /*pos            =*/ in_batch.pos,
-        /*n_seq_id       =*/ in_batch.n_seq_id,
-        /*seq_id         =*/ in_batch.seq_id,
-        /*logits         =*/ in_batch.logits,
+        /*n_tokens        =*/ in_batch.n_tokens,
+        /*max_tokens      =*/ in_batch.n_tokens,
+        /*n_pos_per_token =*/ 1,
+        /*is_view         =*/ false,
+        /*tokens          =*/ in_batch.token,
+        /*embd            =*/ in_batch.embd,
+        /*pos             =*/ in_batch.pos,
+        /*n_seq_id        =*/ in_batch.n_seq_id,
+        /*seq_id          =*/ in_batch.seq_id,
+        /*logits          =*/ in_batch.logits,
     };
     GGML_ASSERT(batch->n_tokens > 0);
     if (!in_batch.pos) {

@@ -338,17 +339,18 @@ struct llama_batch llama_batch_get_one(
     };
 }

-static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
+static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max, int32_t n_pos_per_token) {
     llama_batch_ext * batch = new llama_batch_ext{
-        /*n_tokens       =*/ 0,
-        /*max_tokens     =*/ n_tokens_alloc,
-        /*is_view        =*/ false,
-        /*tokens         =*/ nullptr,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
+        /*n_tokens        =*/ 0,
+        /*max_tokens      =*/ n_tokens_alloc,
+        /*n_pos_per_token =*/ n_pos_per_token,
+        /*is_view         =*/ false,
+        /*tokens          =*/ nullptr,
+        /*embd            =*/ nullptr,
+        /*pos             =*/ nullptr,
+        /*n_seq_id        =*/ nullptr,
+        /*seq_id          =*/ nullptr,
+        /*logits          =*/ nullptr,
     };

     if (n_embd) {

@@ -371,7 +373,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
 }

 struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
-    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx), n_pos_per_token);
 }

 struct llama_batch_ext * llama_batch_ext_init_from_embd(

@@ -381,10 +384,10 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
         size_t n_embd,
         const llama_pos * pos,
         llama_seq_id seq_id) {
-    auto model = llama_get_model(ctx);
-    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1, n_pos_per_token);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
-    memcpy(batch->pos, pos, n_tokens * llama_n_pos_per_token(model) * sizeof(llama_pos));
+    memcpy(batch->pos, pos, n_tokens * n_pos_per_token * sizeof(llama_pos));
     for (size_t i = 0; i < n_tokens; i++) {
         batch->n_seq_id[i] = 1;
         batch->seq_id [i][0] = seq_id;

@@ -411,12 +414,16 @@ int32_t llama_batch_ext_add_text(
     }
     const int32_t output_id = batch->n_tokens;
     batch->token [output_id] = token;
-    batch->pos   [output_id] = pos;
+    batch->n_seq_id[output_id] = n_seq_ids;
+    batch->logits  [output_id] = output;
+    for (int32_t i = 0; i < batch->n_pos_per_token; i++) {
+        // TODO: this is only used by qwen2vl for now, and text tokens only have 3 pos, the last is set to 0; we should improve this code in the future
+        batch->pos[output_id * batch->n_pos_per_token + i] = i < 3 ? pos : 0;
+    }
     batch->n_seq_id[output_id] = n_seq_ids;
     for (size_t j = 0; j < n_seq_ids; j++) {
         batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-    batch->logits [output_id] = output;
     batch->n_tokens++;
     return output_id;
 }

@@ -461,15 +468,16 @@ struct llama_batch_ext * llama_batch_ext_get_view(
         return nullptr; // not yet supported
     }
     llama_batch_ext * batch_view = new llama_batch_ext{
-        /*n_tokens       =*/ n_tokens,
-        /*max_tokens     =*/ n_tokens,
-        /*is_view        =*/ true,
-        /*tokens         =*/ batch->token + offset,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ batch->pos + offset,
-        /*n_seq_id       =*/ batch->n_seq_id + offset,
-        /*seq_id         =*/ batch->seq_id + offset,
-        /*logits         =*/ batch->logits + offset,
+        /*n_tokens        =*/ n_tokens,
+        /*max_tokens      =*/ n_tokens,
+        /*n_pos_per_token =*/ batch->n_pos_per_token,
+        /*is_view         =*/ true,
+        /*tokens          =*/ batch->token + offset,
+        /*embd            =*/ nullptr,
+        /*pos             =*/ batch->pos + offset * batch->n_pos_per_token,
+        /*n_seq_id        =*/ batch->n_seq_id + offset,
+        /*seq_id          =*/ batch->seq_id + offset,
+        /*logits          =*/ batch->logits + offset,
     };
     return batch_view;
 }
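To make the new indexing in llama_batch_ext_add_text concrete, here is an isolated sketch of the same fill pattern: the scalar position of a text token is copied into the first three components of its slot and the fourth component is zeroed, as the TODO comment describes. Buffer size and values are invented for illustration.

```cpp
#include <cstdio>
#include <vector>

// Illustration of the pos fill in llama_batch_ext_add_text for one token.
// Assumes n_pos_per_token == 4; components 0..2 get the scalar position, component 3 gets 0.
int main() {
    const int n_pos_per_token = 4;
    const int output_id = 2;      // the token slot being written
    const int pos = 57;           // scalar position of that token
    std::vector<int> batch_pos(8 * n_pos_per_token, -1); // pretend batch with 8 slots
    for (int i = 0; i < n_pos_per_token; i++) {
        batch_pos[output_id * n_pos_per_token + i] = i < 3 ? pos : 0;
    }
    printf("slot %d: %d %d %d %d\n", output_id,
           batch_pos[output_id*4 + 0], batch_pos[output_id*4 + 1],
           batch_pos[output_id*4 + 2], batch_pos[output_id*4 + 3]);
    return 0;
}
```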

src/llama-batch.h

Lines changed: 2 additions & 1 deletion

@@ -21,11 +21,12 @@
 struct llama_batch_ext {
     int32_t n_tokens;
     int32_t max_tokens;
+    int32_t n_pos_per_token = 1;
     bool is_view;

     llama_token * token;
     float * embd;
-    llama_pos * pos;
+    llama_pos * pos; // if multi pos per token: 000011112222...
     int32_t * n_seq_id;
     llama_seq_id ** seq_id;
     int8_t * logits; // TODO: rename this to "output"
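With the new field, the positions of one token occupy a contiguous run of n_pos_per_token entries, so component d of token t lives at pos[t * n_pos_per_token + d]. A small hedged helper expressing that; the accessor name is invented and is not part of the commit.

```cpp
#include "llama-batch.h" // internal header defining llama_batch_ext (this branch)

// Hypothetical accessor, for illustration only: component d of token t
// in the token-major layout (000011112222...).
static inline llama_pos batch_ext_get_pos(const llama_batch_ext & b, int t, int d) {
    return b.pos[t * b.n_pos_per_token + d];
}
```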

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions

@@ -6075,6 +6075,11 @@ struct llm_build_qwen2vl : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();

+        // TODO @ngxson : transpose layout 0000111122223333 to 0123012301230123, we should improve this in the future
+        inp_pos = ggml_reshape_2d(ctx0, inp_pos, n_tokens, n_pos_per_token);
+        inp_pos = ggml_cont(ctx0, ggml_transpose(ctx0, inp_pos));
+        inp_pos = ggml_reshape_1d(ctx0, inp_pos, n_pos_per_token * n_tokens);
+
         auto * inp_attn = build_attn_inp_kv_unified();

         int sections[4];
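Per the comment, this graph-side shuffle is intended as the inverse of the CLI-side one: it takes the token-major buffer coming from the batch (all components of token 0 adjacent, then token 1, and so on) and produces the blocked, component-major layout that the M-RoPE rope path reads, one run of n_tokens entries per component. Below is a plain-C++ sketch of that index mapping, not a line-by-line translation of the ggml calls, with made-up values.

```cpp
#include <cstdio>
#include <vector>

// Equivalent index mapping of the intended reshape + transpose, shown on a plain buffer.
int main() {
    const int n_tokens = 3, n_pos_per_token = 4;
    // token-major input, as stored in the batch: all components of token 0, then token 1, ...
    std::vector<int> token_major = {10,20,30,40,  11,21,31,41,  12,22,32,42};
    std::vector<int> component_major(n_tokens * n_pos_per_token);
    for (int d = 0; d < n_pos_per_token; d++) {
        for (int t = 0; t < n_tokens; t++) {
            component_major[d*n_tokens + t] = token_major[t*n_pos_per_token + d];
        }
    }
    for (int d = 0; d < n_pos_per_token; d++) {
        printf("component %d:", d);
        for (int t = 0; t < n_tokens; t++) printf(" %d", component_major[d*n_tokens + t]);
        printf("\n");
    }
    return 0;
}
```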
