@@ -2835,24 +2835,7 @@ def __call__(
                     )
                     llama.eval(tokens)
                 else:
-                    image_bytes = self.load_image(value)
-                    embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
-                    if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
-                        raise ValueError(
-                            f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
-                        )
-                    n_past = ctypes.c_int(llama.n_tokens)
-                    n_past_p = ctypes.pointer(n_past)
-                    with suppress_stdout_stderr(disable=self.verbose):
-                        self._llava_cpp.llava_eval_image_embed(
-                            llama.ctx,
-                            embed,
-                            llama.n_batch,
-                            n_past_p,
-                        )
-                    # Required to avoid issues with hf tokenizer
-                    llama.input_ids[llama.n_tokens : n_past.value] = -1
-                    llama.n_tokens = n_past.value
+                    self.eval_image(llama, value)

         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
@@ -2938,6 +2921,26 @@ def __call__(
         )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        image_bytes = self.load_image(image_url)
+        embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+        if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
+            raise ValueError(
+                f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
+            )
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)
+        with suppress_stdout_stderr(disable=self.verbose):
+            self._llava_cpp.llava_eval_image_embed(
+                llama.ctx,
+                embed,
+                llama.n_batch,
+                n_past_p,
+            )
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value
+

     @staticmethod
     def _load_image(image_url: str) -> bytes:
         # TODO: Add Pillow support for other image formats beyond (jpg, png)
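
The two hunks above move the LLaVA image-embedding logic out of __call__ and into an overridable eval_image method, so a vision handler that evaluates images differently only needs to replace that one method. A minimal sketch of the idea, assuming only the existing Llava15ChatHandler base class (the subclass name and logging are hypothetical, not part of this diff):

# Illustrative sketch only: a hypothetical subclass that customises image
# evaluation by overriding the eval_image hook introduced above.
from llama_cpp.llama_chat_format import Llava15ChatHandler

class LoggingVisionHandler(Llava15ChatHandler):  # hypothetical name
    def eval_image(self, llama, image_url: str):
        # Any per-image preprocessing or logging could happen here before
        # delegating to the stock LLaVA embedding path.
        print(f"evaluating image: {image_url}")
        super().eval_image(llama, image_url)
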
@@ -3435,10 +3438,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
             if pos != -1:
                 assert len(copied_urls) > 0
                 if pos > 0:
-                    split_text += [("text", remaining[:pos])]
-                split_text += [("text", "\n\n<start_of_image>")]
-                split_text += [("image_url", copied_urls.pop(0))]
-                split_text += [("text", "<end_of_image>\n\n")]
+                    split_text.append(("text", remaining[:pos]))
+                split_text.append(("text", "\n\n<start_of_image>"))
+                split_text.append(("image_url", copied_urls.pop(0)))
+                split_text.append(("text", "<end_of_image>\n\n"))
                 remaining = remaining[pos + len(image_placeholder):]
             else:
                 assert len(copied_urls) == 0
@@ -3461,6 +3464,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
                     image_urls.append(content["url"])
         return image_urls

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        import llama_cpp
+
+        img_bytes = self.load_image(image_url)
+        img_u8_p = self._llava_cpp.clip_image_u8_init()
+        if not self._llava_cpp.clip_image_load_from_bytes(
+            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
+            ctypes.c_size_t(len(img_bytes)),
+            img_u8_p,
+        ):
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to load image.")
+
+        img_f32 = self._llava_cpp.clip_image_f32_batch()
+        img_f32_p = ctypes.byref(img_f32)
+        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to preprocess image.")
+
+        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
+        n_tokens = 256
+        embed = (ctypes.c_float * (n_tokens * n_embd))()
+        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to encode image.")
+
+        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+        self._llava_cpp.clip_image_u8_free(img_u8_p)
+        llama_cpp.llama_set_causal_attn(llama.ctx, False)
+
+        seq_id_0 = (ctypes.c_int32 * 1)()
+        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
+        for i in range(n_tokens):
+            seq_ids[i] = seq_id_0
+
+        batch = llama_cpp.llama_batch()
+        batch.n_tokens = n_tokens
+        batch.token = None
+        batch.embd = embed
+        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
+        batch.seq_id = seq_ids
+        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
+        batch.logits = (ctypes.c_int8 * n_tokens)()
+
+        if llama_cpp.llama_decode(llama.ctx, batch):
+            raise ValueError("Failed to decode image.")
+
+        llama_cpp.llama_set_causal_attn(llama.ctx, True)
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
+        llama.n_tokens += n_tokens
+

 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
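
For orientation, this is roughly how a handler built on these eval_image methods is wired up through the public llama-cpp-python API; the model and projector paths below are placeholders, and Llava15ChatHandler stands in for whichever handler class this diff targets:

# Rough usage sketch; model/projector paths are placeholders.
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="mmproj.gguf")
llm = Llama(
    model_path="model.gguf",
    chat_handler=chat_handler,
    n_ctx=4096,  # image embeddings occupy context positions, so leave headroom
)
result = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
)
print(result["choices"][0]["message"]["content"])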