Skip to content

Commit 438bb69

Browse files
tc-mbiceflame89harvestingmoon
authored andcommitted
llava : support MiniCPM-V-2.6 (ggml-org#8967)
* init * rename * add run android for termux in readme * add android readme * add instructions in readme * change name in readme * Update README.md * fixed line * add result in readme * random pos_embed * add positions index * change for ollama * change for ollama * better pos_embed in clip * support ollama * updata cmakelist * updata cmakelist * rename wrapper * clear code * replace and organize code * add link * sync master * fix warnings * fix warnings * fix bug in bicubic resize when need resize iamge smaller * receive review comments and modify * receive review comments and modify * put all code into llava dir * fix quality problem in pr code * change n_layer * add space in "-1" * imitate reshape bug of python code * fix bug in clip * fix issues for merging * fix llama-minicpmv-cli in cmake file * change pr readme * fix code review * remove in line 33 directory in the /cmakelists.txt (not in example, in the main dir * fix cmakefile * add warn * fix KEY_HAS_MINICPMV_PROJ * remove load_image_size into clip_ctx * remove the extern "C", MINICPMV_API * fix uhd code for review comment * delete minicpmv-wrapper in pr * remove uhd_image_embed * Modify 2 notes * support minicpmv2.6 * modify convert script of minicpmv * modify convert * modify convert * add readme * add resampler of v2.6 * modify clip * modify readme * fix type-check * fix type-check * fix type-check * fix type-check * modify convert script and readme * fix convert script and readme * fix convert * fix num in convert * fix type-check --------- Co-authored-by: Hongji Zhu <fireyoucan@gmail.com> Co-authored-by: harvestingmoon <leewenyeong@gmail.com>
1 parent 5b85cc6 commit 438bb69

8 files changed

+645
-35
lines changed

examples/llava/README-minicpmv2.5.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ Convert PyTorch model to gguf files (You can also download the converted [gguf](
1616

1717
```bash
1818
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
19-
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
20-
python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model
19+
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2
20+
python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
2121

2222
# quantize int4 version
2323
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M

examples/llava/README-minicpmv2.6.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
## MiniCPM-V 2.6
2+
3+
### Prepare models and code
4+
5+
Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder.
6+
7+
Clone llama.cpp:
8+
```bash
9+
git clone git@github.com:OpenBMB/llama.cpp.git
10+
cd llama.cpp
11+
git checkout minicpmv-main
12+
```
13+
14+
### Usage of MiniCPM-V 2.6
15+
16+
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us)
17+
18+
```bash
19+
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6
20+
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3
21+
python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
22+
23+
# quantize int4 version
24+
./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
25+
```
26+
27+
Build for Linux or Mac
28+
29+
```bash
30+
make
31+
make llama-minicpmv-cli
32+
```
33+
34+
Inference on Linux or Mac
35+
```
36+
# run f16 version
37+
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
38+
39+
# run quantized int4 version
40+
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
41+
42+
# or run in interactive mode
43+
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
44+
```
45+
46+
### Video
47+
Install FFmpeg
48+
```
49+
brew install ffmpeg
50+
brew install pkg-config
51+
```
52+
53+
### Android
54+
55+
#### Build on Android device using Termux
56+
We found that build on Android device would bring better runtime performance, so we recommend to build on device.
57+
58+
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
59+
60+
Install tools in Termux:
61+
```
62+
apt update && apt upgrade -y
63+
apt install git make cmake
64+
```
65+
66+
It's recommended to move your model inside the `~/` directory for best performance:
67+
```
68+
cd storage/downloads
69+
mv model.gguf ~/
70+
```
71+
72+
#### Building the Project using Android NDK
73+
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
74+
75+
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
76+
77+
```bash
78+
mkdir build-android
79+
cd build-android
80+
export NDK=/your_ndk_path
81+
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
82+
make
83+
```
84+
85+
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
86+
87+
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
88+
89+
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
90+
```
91+
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
92+
$cd /data/data/com.termux/files/home/bin
93+
$chmod +x ./*
94+
```
95+
96+
Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/`
97+
98+
```
99+
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
100+
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
101+
```
102+
103+
Now, you can start chatting:
104+
```
105+
$cd /data/data/com.termux/files/home/bin
106+
$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
107+
```

examples/llava/clip.cpp

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ static std::string format(const char * fmt, ...) {
8181
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
8282
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
8383
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
84+
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
8485
#define KEY_USE_GELU "clip.use_gelu"
8586
#define KEY_N_EMBD "clip.%s.embedding_length"
8687
#define KEY_N_FF "clip.%s.feed_forward_length"
@@ -526,6 +527,7 @@ struct clip_ctx {
526527
bool has_vision_encoder = false;
527528
bool has_llava_projector = false;
528529
bool has_minicpmv_projector = false;
530+
int minicpmv_version = 2;
529531

530532
struct clip_vision_model vision_model;
531533
projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -641,7 +643,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
641643
if (ctx->has_minicpmv_projector) {
642644
int pos_w = image_size_width/patch_size;
643645
int pos_h = image_size_height/patch_size;
644-
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
646+
if (ctx->minicpmv_version == 2) {
647+
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
648+
}
649+
else if (ctx->minicpmv_version == 3) {
650+
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
651+
}
645652
ggml_set_name(pos_embed, "pos_embed");
646653
ggml_set_input(pos_embed);
647654
}
@@ -768,8 +775,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
768775
embeddings = ggml_gelu(ctx0, embeddings);
769776
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
770777
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
771-
772-
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
778+
}
779+
else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
773780
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
774781
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
775782
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@@ -949,10 +956,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
949956
}
950957

951958
{ // attention
952-
const int hidden_size = 4096;
959+
int hidden_size = 4096;
953960
const int d_head = 128;
954-
const int n_head = hidden_size/d_head;
955-
const int num_query = 96;
961+
int n_head = hidden_size/d_head;
962+
int num_query = 96;
963+
if (ctx->minicpmv_version == 2) {
964+
hidden_size = 4096;
965+
n_head = hidden_size/d_head;
966+
num_query = 96;
967+
}
968+
else if (ctx->minicpmv_version == 3) {
969+
hidden_size = 3584;
970+
n_head = hidden_size/d_head;
971+
num_query = 64;
972+
}
956973

957974
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
958975
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@@ -1149,6 +1166,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
11491166
new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
11501167
}
11511168

1169+
idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
1170+
if (idx != -1) {
1171+
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
1172+
}
1173+
11521174
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
11531175

11541176
GGML_ASSERT(new_clip->has_vision_encoder);
@@ -1910,10 +1932,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
19101932
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
19111933
// res_imgs memory is being allocated here, previous allocations will be freed if found
19121934
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
1913-
if (clip_is_minicpmv(ctx)) {
1914-
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
1935+
1936+
if(clip_is_minicpmv(ctx)){
1937+
int max_slice_nums = 9;
1938+
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
19151939
res_imgs->size = 0;
1916-
for (size_t i = 0; i < imgs.size(); ++i) {
1940+
for (size_t i = 0; i < imgs.size(); ++i){
19171941
res_imgs->size += imgs[i].size();
19181942
}
19191943
res_imgs->data = new clip_image_f32[res_imgs->size];
@@ -2146,7 +2170,12 @@ int clip_n_patches(const struct clip_ctx * ctx) {
21462170
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
21472171
n_patches /= 4;
21482172
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
2149-
n_patches = 96;
2173+
if (ctx->minicpmv_version == 2) {
2174+
n_patches = 96;
2175+
}
2176+
else if (ctx->minicpmv_version == 3) {
2177+
n_patches = 64;
2178+
}
21502179
}
21512180

21522181
return n_patches;
@@ -2282,6 +2311,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
22822311
const int patch_size = hparams.patch_size;
22832312
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
22842313
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
2314+
if(ctx->load_image_size==nullptr){
2315+
ctx->load_image_size= clip_image_size_init();
2316+
}
2317+
const int pos_w = ctx->load_image_size->width/patch_size;
2318+
const int pos_h = ctx->load_image_size->height/patch_size;
22852319

22862320
{
22872321
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
@@ -2316,8 +2350,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
23162350
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
23172351
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
23182352
int* positions_data = (int*)malloc(ggml_nbytes(positions));
2319-
for (int i = 0; i < num_positions; i++) {
2320-
positions_data[i] = std::floor(70.0*i/num_positions);
2353+
int bucket_coords_h[70];
2354+
int bucket_coords_w[70];
2355+
for (int i = 0; i < pos_h; i++){
2356+
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
2357+
}
2358+
for (int i = 0; i < pos_w; i++){
2359+
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
2360+
}
2361+
for (int i = 0, id = 0; i < pos_h; i++){
2362+
for (int j = 0; j < pos_w; j++){
2363+
positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
2364+
}
23212365
}
23222366
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
23232367
free(positions_data);
@@ -2328,12 +2372,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
23282372
// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
23292373
// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
23302374
struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
2331-
if(ctx->load_image_size==nullptr){
2332-
ctx->load_image_size= clip_image_size_init();
2333-
}
2334-
int pos_w = ctx->load_image_size->width/patch_size;
2335-
int pos_h = ctx->load_image_size->height/patch_size;
23362375
int embed_dim = 4096;
2376+
if (ctx->minicpmv_version == 2) {
2377+
embed_dim = 4096;
2378+
}
2379+
else if (ctx->minicpmv_version == 3) {
2380+
embed_dim = 3584;
2381+
}
23372382
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
23382383

23392384
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@@ -2346,7 +2391,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
23462391
ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
23472392
free(pos_embed_data);
23482393
}
2349-
} else {
2394+
}
2395+
else{
23502396
{
23512397
if (ctx->has_class_embedding) {
23522398
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
@@ -2548,13 +2594,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
25482594
return ctx->vision_model.mm_3_b->ne[0];
25492595
}
25502596
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
2551-
return 4096;
2597+
if (ctx->minicpmv_version == 2) {
2598+
return 4096;
2599+
}
2600+
else if (ctx->minicpmv_version == 3) {
2601+
return 3584;
2602+
}
25522603
}
25532604

25542605
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
25552606
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
25562607
}
25572608

2558-
bool clip_is_minicpmv(const struct clip_ctx * ctx) {
2559-
return ctx->has_minicpmv_projector;
2609+
int clip_is_minicpmv(const struct clip_ctx * ctx) {
2610+
if (ctx->has_minicpmv_projector) {
2611+
return ctx->minicpmv_version;
2612+
}
2613+
return 0;
25602614
}

examples/llava/clip.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
8585

8686
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
8787

88-
CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx);
88+
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
8989

9090
#ifdef __cplusplus
9191
}

examples/llava/llava.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
256256
load_image_size->width = img_res_v.data[i].nx;
257257
load_image_size->height = img_res_v.data[i].ny;
258258
clip_add_load_image_size(ctx_clip, load_image_size);
259-
const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
259+
bool encoded = false;
260+
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
261+
if (has_minicpmv_projector == 2) {
262+
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
263+
}
264+
else if (has_minicpmv_projector == 3) {
265+
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
266+
}
260267
if (!encoded) {
261268
LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
262269
return false;

examples/llava/minicpmv-cli.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
134134
std::string system_prompt;
135135
int idx = 0;
136136
int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
137-
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
137+
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
138+
if (has_minicpmv_projector == 2) {
139+
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
140+
}
141+
else if (has_minicpmv_projector == 3) {
142+
system_prompt = "<|im_start|>user\n";
143+
}
138144
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
139145
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
140146
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
210216

211217
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
212218
std::string user_prompt = prompt;
213-
if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
219+
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
220+
if (!is_first) {
221+
if (has_minicpmv_projector == 2) {
222+
user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
223+
}
224+
else if (has_minicpmv_projector == 3) {
225+
user_prompt = "<|im_start|>user\n" + prompt;
226+
}
227+
}
214228

215229
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
216-
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
230+
if (has_minicpmv_projector == 2) {
231+
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
232+
}
233+
else if (has_minicpmv_projector == 3) {
234+
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
235+
}
236+
217237
// generate the response
218238

219239
LOG_TEE("\n");

0 commit comments

Comments
 (0)