Skip to content

Commit 25b2f8f

Browse files
committed
resolve the image embedding issue in gemma3
1 parent f33dde3 commit 25b2f8f

File tree

2 files changed: +191 additions, −22 deletions

llama_cpp/llama_chat_format.py

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,24 +2835,7 @@ def __call__(
28352835
)
28362836
llama.eval(tokens)
28372837
else:
2838-
image_bytes = self.load_image(value)
2839-
embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
2840-
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
2841-
raise ValueError(
2842-
f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
2843-
)
2844-
n_past = ctypes.c_int(llama.n_tokens)
2845-
n_past_p = ctypes.pointer(n_past)
2846-
with suppress_stdout_stderr(disable=self.verbose):
2847-
self._llava_cpp.llava_eval_image_embed(
2848-
llama.ctx,
2849-
embed,
2850-
llama.n_batch,
2851-
n_past_p,
2852-
)
2853-
# Required to avoid issues with hf tokenizer
2854-
llama.input_ids[llama.n_tokens : n_past.value] = -1
2855-
llama.n_tokens = n_past.value
2838+
self.eval_image(llama, value)
28562839

28572840
# Get prompt tokens to avoid a cache miss
28582841
prompt = llama.input_ids[: llama.n_tokens].tolist()
@@ -2938,6 +2921,26 @@ def __call__(
29382921
)
29392922
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
29402923

2924+
def eval_image(self, llama: llama.Llama, image_url: str):
2925+
image_bytes = self.load_image(image_url)
2926+
embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
2927+
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
2928+
raise ValueError(
2929+
f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
2930+
)
2931+
n_past = ctypes.c_int(llama.n_tokens)
2932+
n_past_p = ctypes.pointer(n_past)
2933+
with suppress_stdout_stderr(disable=self.verbose):
2934+
self._llava_cpp.llava_eval_image_embed(
2935+
llama.ctx,
2936+
embed,
2937+
llama.n_batch,
2938+
n_past_p,
2939+
)
2940+
# Required to avoid issues with hf tokenizer
2941+
llama.input_ids[llama.n_tokens : n_past.value] = -1
2942+
llama.n_tokens = n_past.value
2943+
29412944
@staticmethod
29422945
def _load_image(image_url: str) -> bytes:
29432946
# TODO: Add Pillow support for other image formats beyond (jpg, png)
@@ -3435,10 +3438,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
34353438
if pos != -1:
34363439
assert len(copied_urls) > 0
34373440
if pos > 0:
3438-
split_text += [("text", remaining[:pos])]
3439-
split_text += [("text", "\n\n<start_of_image>")]
3440-
split_text += [("image_url", copied_urls.pop(0))]
3441-
split_text += [("text", "<end_of_image>\n\n")]
3441+
split_text.append(("text", remaining[:pos]))
3442+
split_text.append(("text", "\n\n<start_of_image>"))
3443+
split_text.append(("image_url", copied_urls.pop(0)))
3444+
split_text.append(("text", "<end_of_image>\n\n"))
34423445
remaining = remaining[pos + len(image_placeholder):]
34433446
else:
34443447
assert len(copied_urls) == 0
@@ -3461,6 +3464,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
34613464
image_urls.append(content["url"])
34623465
return image_urls
34633466

3467+
def eval_image(self, llama: llama.Llama, image_url: str):
3468+
import llama_cpp
3469+
3470+
img_bytes = self.load_image(image_url)
3471+
img_u8_p = self._llava_cpp.clip_image_u8_init()
3472+
if not self._llava_cpp.clip_image_load_from_bytes(
3473+
ctypes.create_string_buffer(img_bytes, len(img_bytes)),
3474+
ctypes.c_size_t(len(img_bytes)),
3475+
img_u8_p,
3476+
):
3477+
self._llava_cpp.clip_image_u8_free(img_u8_p)
3478+
raise ValueError("Failed to load image.")
3479+
3480+
img_f32 = self._llava_cpp.clip_image_f32_batch()
3481+
img_f32_p = ctypes.byref(img_f32)
3482+
if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
3483+
self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
3484+
self._llava_cpp.clip_image_u8_free(img_u8_p)
3485+
raise ValueError("Failed to preprocess image.")
3486+
3487+
n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
3488+
n_tokens = 256
3489+
embed = (ctypes.c_float * (n_tokens * n_embd))()
3490+
if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
3491+
self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
3492+
self._llava_cpp.clip_image_u8_free(img_u8_p)
3493+
raise ValueError("Failed to encode image.")
3494+
3495+
self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
3496+
self._llava_cpp.clip_image_u8_free(img_u8_p)
3497+
llama_cpp.llama_set_causal_attn(llama.ctx, False)
3498+
3499+
seq_id_0 = (ctypes.c_int32 * 1)()
3500+
seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
3501+
for i in range(n_tokens):
3502+
seq_ids[i] = seq_id_0
3503+
3504+
batch = llama_cpp.llama_batch()
3505+
batch.n_tokens = n_tokens
3506+
batch.token = None
3507+
batch.embd = embed
3508+
batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
3509+
batch.seq_id = seq_ids
3510+
batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
3511+
batch.logits = (ctypes.c_int8 * n_tokens)()
3512+
3513+
if llama_cpp.llama_decode(llama.ctx, batch):
3514+
raise ValueError("Failed to decode image.")
3515+
3516+
llama_cpp.llama_set_causal_attn(llama.ctx, True)
3517+
# Required to avoid issues with hf tokenizer
3518+
llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
3519+
llama.n_tokens += n_tokens
3520+
34643521

34653522
@register_chat_completion_handler("chatml-function-calling")
34663523
def chatml_function_calling(

llama_cpp/llava_cpp.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
c_int,
88
c_uint8,
99
c_float,
10+
c_size_t,
1011
c_void_p,
1112
POINTER,
1213
_Pointer, # type: ignore
@@ -141,6 +142,28 @@ def llava_eval_image_embed(
141142
################################################
142143

143144

145+
# struct clip_image_u8_batch {
146+
# struct clip_image_u8 * data;
147+
# size_t size;
148+
# };
149+
class clip_image_u8_batch(Structure):
150+
_fields_ = [
151+
("data", c_void_p),
152+
("size", c_size_t),
153+
]
154+
155+
156+
# struct clip_image_f32_batch {
157+
# struct clip_image_f32 * data;
158+
# size_t size;
159+
# };
160+
class clip_image_f32_batch(Structure):
161+
_fields_ = [
162+
("data", c_void_p),
163+
("size", c_size_t),
164+
]
165+
166+
144167
# /** load mmproj model */
145168
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
146169
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
@@ -156,3 +179,92 @@ def clip_model_load(
156179
def clip_free(ctx: clip_ctx_p, /):
157180
...
158181

182+
183+
# CLIP_API struct clip_image_u8 * clip_image_u8_init ();
184+
@ctypes_function("clip_image_u8_init", [], c_void_p)
185+
def clip_image_u8_init() -> Optional[c_void_p]:
186+
...
187+
188+
189+
# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
190+
@ctypes_function("clip_image_u8_free", [c_void_p], None)
191+
def clip_image_u8_free(img: c_void_p, /):
192+
...
193+
194+
195+
# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
196+
@ctypes_function("clip_image_f32_free", [c_void_p], None)
197+
def clip_image_f32_free(img: c_void_p, /):
198+
...
199+
200+
201+
# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
202+
@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None)
203+
def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /):
204+
...
205+
206+
207+
# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
208+
@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None)
209+
def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /):
210+
...
211+
212+
213+
# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
214+
# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
215+
@ctypes_function(
216+
"clip_image_preprocess",
217+
[
218+
clip_ctx_p_ctypes,
219+
c_void_p,
220+
POINTER(clip_image_f32_batch),
221+
],
222+
c_bool,
223+
)
224+
def clip_image_preprocess(
225+
ctx: clip_ctx_p,
226+
img: c_void_p,
227+
res_imgs: "_Pointer[clip_image_f32_batch]",
228+
/,
229+
) -> bool:
230+
...
231+
232+
233+
# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
234+
@ctypes_function(
235+
"clip_image_batch_encode",
236+
[
237+
clip_ctx_p_ctypes,
238+
c_int,
239+
POINTER(clip_image_f32_batch),
240+
POINTER(c_float),
241+
],
242+
c_bool,
243+
)
244+
def clip_image_batch_encode(
245+
ctx: clip_ctx_p,
246+
n_threads: c_int,
247+
imgs: "_Pointer[clip_image_f32_batch]",
248+
vec: c_void_p
249+
) -> bool:
250+
...
251+
252+
253+
# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
254+
# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
255+
@ctypes_function(
256+
"clip_image_load_from_bytes",
257+
[
258+
c_void_p,
259+
c_size_t,
260+
c_void_p,
261+
],
262+
c_bool,
263+
)
264+
def clip_image_load_from_bytes(
265+
bytes: c_void_p,
266+
bytes_length: c_size_t,
267+
img: c_void_p,
268+
/,
269+
) -> bool:
270+
...

0 commit comments

Comments (0)