2
2
#include " llava.h"
3
3
4
4
#include " llama.h"
5
- #include " common.h"
6
5
7
6
#include < algorithm>
8
7
#include < cerrno>
@@ -402,6 +401,38 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
402
401
return true ;
403
402
}
404
403
404
+ struct llava_embd_batch {
405
+ std::vector<llama_pos> pos;
406
+ std::vector<int32_t > n_seq_id;
407
+ std::array <llama_seq_id, 1 > seq_id_0;
408
+ std::vector<llama_seq_id *> seq_ids;
409
+ std::vector<int8_t > logits;
410
+ llama_batch batch;
411
+ llava_embd_batch (float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
412
+ pos .resize (n_tokens);
413
+ n_seq_id.resize (n_tokens);
414
+ seq_ids .resize (n_tokens + 1 );
415
+ logits .resize (n_tokens);
416
+ seq_id_0[0 ] = seq_id;
417
+ seq_ids [n_tokens] = nullptr ;
418
+ batch = {
419
+ /* n_tokens =*/ n_tokens,
420
+ /* tokens =*/ nullptr ,
421
+ /* embd =*/ embd,
422
+ /* pos =*/ pos.data (),
423
+ /* n_seq_id =*/ n_seq_id.data (),
424
+ /* seq_id =*/ seq_ids.data (),
425
+ /* logits =*/ logits.data (),
426
+ };
427
+ for (int i = 0 ; i < n_tokens; i++) {
428
+ batch.pos [i] = pos_0 + i;
429
+ batch.n_seq_id [i] = 1 ;
430
+ batch.seq_id [i] = seq_id_0.data ();
431
+ batch.logits [i] = false ;
432
+ }
433
+ }
434
+ };
435
+
405
436
bool llava_eval_image_embed (llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
406
437
int n_embd = llama_n_embd (llama_get_model (ctx_llama));
407
438
@@ -411,8 +442,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
411
442
n_eval = n_batch;
412
443
}
413
444
float * embd = image_embed->embed +i*n_embd;
414
- llama_batch batch = llama_batch_get_one (embd, n_eval, *n_past, 0 );
415
- if (llama_decode (ctx_llama, batch)) {
445
+ llava_embd_batch llava_batch = llava_embd_batch (embd, n_eval, *n_past, 0 );
446
+ if (llama_decode (ctx_llama, llava_batch. batch )) {
416
447
LOG_ERR (" %s : failed to eval\n " , __func__);
417
448
return false ;
418
449
}
0 commit comments