
Commit 6c01174

fix llava and android builds
1 parent 5e8f776

2 files changed, +34 -6 lines changed


examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 0 additions & 3 deletions
@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
         nullptr,
         nullptr,
         nullptr,
-        0,
-        0,
-        0,
     };

     if (embd) {
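
Context for the deletion above: the upstream `llama_batch` struct dropped its trailing legacy fields (`all_pos_0`, `all_pos_1`, `all_seq_id`), so the three extra `0` initializers no longer matched any member and broke the Android build. Below is a minimal sketch of the seven-field layout the fixed initializer targets, inferred from the `/*...=*/` annotations in the llava diff further down; treat the exact field names and ordering in llama.h as an assumption.

    // Sketch (assumed): llama_batch now has exactly seven members,
    // so an aggregate initializer must supply exactly seven values.
    typedef struct llama_batch {
        int32_t         n_tokens;
        llama_token  *  token;     // text tokens, or nullptr when embd is used
        float        *  embd;      // embeddings, or nullptr when token is used
        llama_pos    *  pos;       // per-token positions
        int32_t      *  n_seq_id;  // number of sequence ids per token
        llama_seq_id ** seq_id;    // per-token arrays of sequence ids
        int8_t       *  logits;    // per-token flag: emit logits for this token
    } llama_batch;
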

examples/llava/llava.cpp

Lines changed: 34 additions & 3 deletions
@@ -2,7 +2,6 @@
 #include "llava.h"

 #include "llama.h"
-#include "common.h"

 #include <algorithm>
 #include <cerrno>
@@ -402,6 +401,38 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }

+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::array<llama_seq_id, 1> seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ embd,
+            /*pos      =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));

@@ -411,8 +442,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llama_batch batch = llama_batch_get_one(embd, n_eval, *n_past, 0);
-        if (llama_decode(ctx_llama, batch)) {
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }
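
Design note on the llava fix: `llama_batch_get_one` no longer takes embeddings with explicit position and sequence-id arguments, so the commit builds the batch by hand. The `llava_embd_batch` wrapper owns the `pos`, `n_seq_id`, `seq_ids`, and `logits` storage as members, which keeps the pointers wired into `batch` valid for as long as the wrapper is in scope. A hedged usage sketch follows; the helper name `eval_embd_chunk` is illustrative and not part of the commit.

    // Illustrative helper mirroring the loop body in llava_eval_image_embed:
    // decode n_eval embedding rows starting at position *n_past on sequence 0.
    static bool eval_embd_chunk(llama_context * ctx_llama, float * embd, int n_eval, int * n_past) {
        // The wrapper's vectors back batch.pos / batch.n_seq_id / batch.seq_id /
        // batch.logits, so they must stay alive until llama_decode returns.
        llava_embd_batch llava_batch(embd, n_eval, *n_past, 0);
        if (llama_decode(ctx_llama, llava_batch.batch)) {
            return false; // decode failed
        }
        *n_past += n_eval; // advance the context position past this chunk
        return true;
    }

Stack-allocating the wrapper per chunk, as the loop in the diff does, frees the temporary arrays automatically after each `llama_decode` call.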
