From 9d97ad56ba4c595501abd20eedc9f7c62e5726c7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 8 Mar 2025 18:35:51 +0100 Subject: [PATCH 1/3] (research) experiment with phi-4-multimodal --- convert_hf_to_gguf.py | 26 +- examples/llava/CMakeLists.txt | 7 + examples/llava/clip.cpp | 22 +- examples/llava/phi4mm-cli.cpp | 224 +++++++++++++ .../llava/phi4mm_convert_encoder_to_gguf.py | 314 ++++++++++++++++++ 5 files changed, 589 insertions(+), 4 deletions(-) create mode 100644 examples/llava/phi4mm-cli.cpp create mode 100644 examples/llava/phi4mm_convert_encoder_to_gguf.py diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6358a94e9b55f..de470c409d277 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2398,9 +2398,23 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) -@Model.register("Phi3ForCausalLM") +@Model.register("Phi3ForCausalLM", "Phi4MMForCausalLM") class Phi3MiniModel(Model): model_arch = gguf.MODEL_ARCH.PHI3 + has_vision: bool = False + + # we need to merge the text_config into the root level of hparams + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "vision_lora" in self.hparams: + logger.info("Detected vision encoder, but it will be ignored") + self.has_vision = True + + def write(self): + super().write() + if self.has_vision: + logger.info("NOTE: this script only convert the language model to GGUF") + logger.info(" for the vision model, please use phi4mm_convert_encoder_to_gguf.py") def set_vocab(self): # Phi-4 model uses GPT2Tokenizer @@ -2409,7 +2423,7 @@ def set_vocab(self): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) tokenizer_class = tokenizer_config_json['tokenizer_class'] - if tokenizer_class == 'GPT2Tokenizer': + if tokenizer_class == 'GPT2Tokenizer' or tokenizer_class == 'GPT2TokenizerFast': return self._set_vocab_gpt2() from sentencepiece import SentencePieceProcessor @@ -2575,6 +2589,14 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if self.has_vision: + if name.startswith("model.embed_tokens_extend") or "lora_" in name: + return [] + name = name.replace(".base_layer", "") + return [(self.map_tensor_name(name), data_torch)] + @Model.register("PhiMoEForCausalLM") class PhiMoeModel(Phi3MiniModel): diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt index 319effd199aa4..4fa68b6886ca3 100644 --- a/examples/llava/CMakeLists.txt +++ b/examples/llava/CMakeLists.txt @@ -57,3 +57,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-phi4mm-cli) +add_executable(${TARGET} phi4mm-cli.cpp) +set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-phi4mm-cli) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 
76d4a78520575..b81be1910cee9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -878,6 +878,24 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } + // FIXME: phi-4, wrap this into an "if" condition + int n_tokens = embeddings->ne[1]; + int n_tokens_sqrt = sqrtf(n_tokens); + printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); + embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); + embeddings = ggml_reshape_4d(ctx0, embeddings, n_tokens_sqrt, n_tokens_sqrt, hidden_size, batch_size); + embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, n_tokens / 4, batch_size); + printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); + // mlp + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + embeddings = ggml_gelu(ctx0, embeddings); + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); + // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -2758,7 +2776,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - if (!ctx->has_glm_projector) { + /*if (!ctx->has_glm_projector) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); // The patches vector is used to get rows to index into the embeds with; // we should skip dim 0 only if we have CLS to avoid going out of bounds @@ -2770,7 +2788,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); free(patches_data); - } + }*/ } } diff --git a/examples/llava/phi4mm-cli.cpp b/examples/llava/phi4mm-cli.cpp new file mode 100644 index 0000000000000..3be248a4770bf --- /dev/null +++ b/examples/llava/phi4mm-cli.cpp @@ -0,0 +1,224 @@ +#include "arg.h" +#include "log.h" +#include "common.h" +#include "sampling.h" +#include "clip.h" +#include "stb_image.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include + +struct phi4mm_context { + struct clip_ctx * ctx_clip = NULL; + common_init_result llama_init; + + llama_model * model; + llama_context * lctx; + llama_adapter_lora * vision_lora; + + phi4mm_context(common_params & params) : llama_init(common_init_from_params(params)) { + model = llama_init.model.get(); + lctx = llama_init.context.get(); + vision_lora = llama_init.lora[0].get(); + llama_clear_adapter_lora(lctx); + init_clip_model(params); + } + + void init_clip_model(common_params & params) { + const char * clip_path = params.mmproj.c_str(); + ctx_clip = clip_model_load(clip_path, params.verbosity > 1); + } + + ~phi4mm_context() { + clip_free(ctx_clip); + } +}; + +struct decode_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + 
n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + +struct inp_bitmap { + int nx; + int ny; + std::vector data; +}; + +static void show_additional_info(int /*argc*/, char ** argv) { + GGML_UNUSED(argv); + LOG("TODO\n"); +} + +static void eval_text(phi4mm_context & ctx, int & n_past, std::string input, bool logits_last = false) { + llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); + llama_batch batch = llama_batch_init(tokens.size(), 0, 1); + for (llama_token & t : tokens) { + common_batch_add(batch, t, n_past++, {0}, false); + } + if (logits_last) { + batch.logits[batch.n_tokens - 1] = true; + } + LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); + if (llama_decode(ctx.lctx, batch)) { + GGML_ABORT("Failed to decode\n"); + } +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + common_params params; + + // default values + params.prompt = "<|user|>$what did you see?<|end|><|assistant|>"; + params.n_predict = 64; + params.sampling.temp = 0.0f; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) { + return 1; + } + + common_init(); + + if (params.mmproj.empty() || (params.image.empty())) { + show_additional_info(argc, argv); + return 1; + } + + if (params.lora_adapters.empty()) { + LOG_ERR("error: no vision lora adapters specified\n"); + return 1; + } + + phi4mm_context ctx(params); + printf("%s: %s\n", __func__, params.model.c_str()); + + int n_threads = params.cpuparams.n_threads; + int n_past = 0; + + std::vector prompt_parts = string_split(params.prompt, '$'); + GGML_ASSERT(prompt_parts.size() == 2); + eval_text(ctx, n_past, prompt_parts[0], false); + + // process images + for (auto & image : params.image) { + //break; + std::vector image_embd_v; + int n_embd = llama_model_n_embd(ctx.model); + int n_tokens = 256; + image_embd_v.resize(n_tokens * n_embd); + + bool ok; + struct clip_image_u8 * img_u8 = clip_image_u8_init(); + ok = clip_image_load_from_file(image.c_str(), img_u8); + if (!ok) { + LOG_ERR("Unable to load image %s\n", image.c_str()); + return 1; + } + + clip_image_f32_batch batch_f32; + ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32); + if (!ok) { + LOG_ERR("Unable to preprocess image\n"); + return 1; + } + + LOG("Encoding image %s\n", image.c_str()); + ok = clip_image_batch_encode(ctx.ctx_clip, n_threads, &batch_f32, image_embd_v.data()); + if (!ok) { + LOG_ERR("Unable to encode image\n"); + return 1; + } + + // debug + // for (int i = 0; i < 10; i++) { + // LOG("embd[%d] = %f, %f, %f\n", i, image_embd_v[i*n_embd], image_embd_v[i*n_embd+1], image_embd_v[i*n_embd+2]); + // } + + clip_image_f32_batch_free(&batch_f32); + clip_image_u8_free(img_u8); + + // decode image embeddings + llama_set_adapter_lora(ctx.lctx, ctx.vision_lora, 1.0f); + decode_embd_batch batch_img(image_embd_v.data(), n_tokens, n_past, 0); + if (llama_decode(ctx.lctx, batch_img.batch)) { + LOG_ERR("failed to decode image\n"); + return 1; + } + llama_clear_adapter_lora(ctx.lctx); + n_past += n_tokens; + } + + 
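(Editor's note: the hard-coded n_tokens = 256 in the image loop above follows from the encoder settings used elsewhere in this patch; a minimal sanity check of that arithmetic, assuming the 448 px / 14 px SigLIP configuration written by phi4mm_convert_encoder_to_gguf.py and the 2x2 average pooling added to clip.cpp:

    # each 448x448 image gives 448/14 = 32 patches per side -> 1024 ViT tokens,
    # and the 2x2 average pool in clip.cpp reduces that to 256 tokens for the LLM
    image_size, patch_size, pool = 448, 14, 2
    patches_per_side = image_size // patch_size        # 32
    n_vit_tokens = patches_per_side ** 2               # 1024
    n_llm_tokens = (patches_per_side // pool) ** 2     # 256
    print(n_vit_tokens, n_llm_tokens)                  # 1024 256
)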
eval_text(ctx, n_past, prompt_parts[1], true); + + // generate text + struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(ctx.model); + int n_prompt = n_past; + llama_batch batch = llama_batch_init(1, 0, 1); + while (true) { + int n_generated = n_past - n_prompt; + if (n_generated > params.n_predict) { + printf("\n"); + break; + } + + llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1); + common_sampler_accept(smpl, token_id, true); + printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str()); + fflush(stdout); + + if (llama_vocab_is_eog(vocab, token_id)) { + printf("\n"); + break; + } + + // eval the token + common_batch_clear(batch); + common_batch_add(batch, token_id, n_past++, {0}, true); + if (llama_decode(ctx.lctx, batch)) { + LOG_ERR("failed to decode token\n"); + break; + } + } + + llama_batch_free(batch); + + return 0; +} diff --git a/examples/llava/phi4mm_convert_encoder_to_gguf.py b/examples/llava/phi4mm_convert_encoder_to_gguf.py new file mode 100644 index 0000000000000..e6445acd60215 --- /dev/null +++ b/examples/llava/phi4mm_convert_encoder_to_gguf.py @@ -0,0 +1,314 @@ +import gguf +import argparse +import logging +import sys +import torch +import json +import os +import numpy as np +from typing import cast, ContextManager, Any, Iterator +from pathlib import Path +from torch import Tensor + +logger = logging.getLogger("phi4-mmproj") + + +# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py +# https://huggingface.co/google/siglip-base-patch16-224/blob/main/preprocessor_config.json +# https://github.com/EricLBuehler/mistral.rs/pull/1163/files +SIGLIP_MODEL = { + "model_id": "google/siglip-base-patch16-224", + "image_size": 448, + "patch_size": 14, # I had very had time finding this number + "do_normalize": True, + "do_rescale": True, + "do_resize": True, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_processor_type": "SiglipImageProcessor", + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "processor_class": "SiglipProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "height": 224, + "width": 224 + } +} + + +# (copied from convert_hf_to_gguf.py) +# tree of lazy tensors +class LazyTorchTensor(gguf.LazyBase): + _tensor_type = torch.Tensor + # to keep the type-checker happy + dtype: torch.dtype + shape: torch.Size + + # only used when converting a torch.Tensor to a np.ndarray + _dtype_map: dict[torch.dtype, type] = { + torch.float16: np.float16, + torch.float32: np.float32, + } + + # used for safetensors slices + # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 + # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 + _dtype_str_map: dict[str, torch.dtype] = { + "F64": torch.float64, + "F32": torch.float32, + "BF16": torch.bfloat16, + "F16": torch.float16, + # "U64": torch.uint64, + "I64": torch.int64, + # "U32": torch.uint32, + "I32": torch.int32, + # "U16": torch.uint16, + "I16": torch.int16, + "U8": torch.uint8, + "I8": torch.int8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + } + + def numpy(self) -> gguf.LazyNumpyTensor: + dtype = self._dtype_map[self.dtype] + return gguf.LazyNumpyTensor( + meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), + args=(self,), + func=(lambda s: s.numpy()) + ) + + @classmethod + def 
meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: + return torch.empty(size=shape, dtype=dtype, device="meta") + + @classmethod + def from_safetensors_slice(cls, st_slice: Any) -> Tensor: + dtype = cls._dtype_str_map[st_slice.get_dtype()] + shape: tuple[int, ...] = tuple(st_slice.get_shape()) + lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + return cast(torch.Tensor, lazy) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.Tensor.numpy: + return args[0].numpy() + + return cls._wrap_fn(func)(*args, **kwargs) + + +class Phi4MM: + hparams: dict + gguf_writer: gguf.GGUFWriter + fname_out: Path + ftype: gguf.LlamaFileType + + @staticmethod + def load_hparams(dir_model: Path): + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: + part_names: list[str] = [] + for filename in os.listdir(dir_model): + if filename.startswith(prefix) and filename.endswith(suffix): + part_names.append(filename) + part_names.sort() + return part_names + + def __init__(self, + dir_model: Path, + fname_out: Path, + ftype: gguf.LlamaFileType, + is_big_endian: bool,): + hparams = Phi4MM.load_hparams(dir_model) + self.hparams = hparams + self.fname_out = fname_out + self.ftype = ftype + endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess) + self.gguf_writer.add_string ("clip.projector_type", "mlp") + self.gguf_writer.add_bool ("clip.has_text_encoder", False) + self.gguf_writer.add_bool ("clip.has_vision_encoder", True) + self.gguf_writer.add_bool ("clip.has_llava_projector", False) + self.gguf_writer.add_uint32 ("clip.vision.image_size", SIGLIP_MODEL["image_size"]) + self.gguf_writer.add_uint32 ("clip.vision.patch_size", SIGLIP_MODEL["patch_size"]) + self.gguf_writer.add_uint32 ("clip.vision.embedding_length", 1152) + self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", 4304) + self.gguf_writer.add_uint32 ("clip.vision.projection_dim", hparams["hidden_size"]) + self.gguf_writer.add_uint32 ("clip.vision.block_count", 12) + self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", 12) + self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", 1e-6) + self.gguf_writer.add_array ("clip.vision.image_mean", SIGLIP_MODEL["image_mean"]) + self.gguf_writer.add_array ("clip.vision.image_std", SIGLIP_MODEL["image_std"]) + self.gguf_writer.add_bool ("clip.use_gelu", False) + + # load tensors + for name, data_torch in self.get_tensors(dir_model): + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + self.add_tensor(name, data_torch) + + def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]: + part_names = Phi4MM.get_model_part_names(dir_model, "model", ".safetensors") + tensor_names_from_parts: set[str] = set() + for part_name in part_names: + logger.info(f"gguf: loading model part '{part_name}'") + from safetensors import safe_open + ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu")) + with ctx as model_part: + tensor_names_from_parts.update(model_part.keys()) + + for name in model_part.keys(): + data 
= model_part.get_slice(name) + data = LazyTorchTensor.from_safetensors_slice(data) + yield name, data + + def add_tensor(self, name: str, data_torch: Tensor): + if not name.startswith("model.embed_tokens_extend.image_embed.") \ + or "img_processor.head." in name \ + or "glb_GN" in name \ + or "sub_GN" in name: + return # skip + + is_1d = len(data_torch.shape) == 1 + is_embd = ".embeddings." in name + old_dtype = data_torch.dtype + can_quantize = not is_1d and not is_embd + data_qtype = gguf.GGMLQuantizationType.F32 + + # prefix + name = name.replace("model.embed_tokens_extend.image_embed.img_processor.", "") + name = name.replace("encoder.", "v.") + name = name.replace("layers.", "blk.") + # projector and input embd + name = name.replace("embeddings.patch_embedding.", "v.patch_embd.") + name = name.replace("embeddings.position_embedding.", "v.position_embd.") + name = name.replace("post_layernorm.", "post_ln.") + # each block + name = name.replace(".self_attn.k_proj.", ".attn_k.") + name = name.replace(".self_attn.v_proj.", ".attn_v.") + name = name.replace(".self_attn.q_proj.", ".attn_q.") + name = name.replace(".self_attn.out_proj.", ".attn_out.") + name = name.replace(".layer_norm1.", ".ln1.") + name = name.replace(".layer_norm2.", ".ln2.") + name = name.replace(".mlp.fc1.", ".ffn_down.") + name = name.replace(".mlp.fc2.", ".ffn_up.") + # projector + name = name.replace("model.embed_tokens_extend.image_embed.img_projection.", "mm.") + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Phi 4 vision encoder safetensors to GGUF format",) + parser.add_argument( + "--outfile", type=Path, default="mmproj.gguf", + help="path to write mmproj file to", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory containing model file", + nargs="?", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are 
required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + if not dir_model.is_dir(): + logger.error(f'Error: {args.model} is not a directory') + sys.exit(1) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model.name}") + + with torch.inference_mode(): + phi4_mm = Phi4MM( + dir_model=dir_model, + fname_out=args.outfile, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + phi4_mm.write() + + +if __name__ == '__main__': + main() From 565d9f32a639accb5785cafac6d05bfe7e494032 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 9 Mar 2025 11:18:48 +0100 Subject: [PATCH 2/3] correct some configs --- examples/llava/phi4mm-test.sh | 14 ++++++++++++++ examples/llava/phi4mm_convert_encoder_to_gguf.py | 7 +++++-- 2 files changed, 19 insertions(+), 2 deletions(-) create mode 100755 examples/llava/phi4mm-test.sh diff --git a/examples/llava/phi4mm-test.sh b/examples/llava/phi4mm-test.sh new file mode 100755 index 0000000000000..2fa9b534adfe4 --- /dev/null +++ b/examples/llava/phi4mm-test.sh @@ -0,0 +1,14 @@ +#!/bin/sh + +# for convenience, we have this script to ease the development process + +# make sure we are in the right directory +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_ROOT="$SCRIPT_DIR/../.." +cd $PROJECT_ROOT + +./build/bin/llama-phi4mm-cli \ + -m ../models/Phi-4-multimodal-instruct/model.gguf \ + --mmproj ../models/Phi-4-multimodal-instruct/mmproj.gguf \ + --lora ../models/Phi-4-multimodal-instruct/vision_lora.gguf \ + --image ../models/bliss.png diff --git a/examples/llava/phi4mm_convert_encoder_to_gguf.py b/examples/llava/phi4mm_convert_encoder_to_gguf.py index e6445acd60215..b26ccdba5520e 100644 --- a/examples/llava/phi4mm_convert_encoder_to_gguf.py +++ b/examples/llava/phi4mm_convert_encoder_to_gguf.py @@ -42,6 +42,8 @@ "width": 224 } } +N_LAYERS = 27 +HEAD_COUNT = 16 # (copied from convert_hf_to_gguf.py) @@ -151,12 +153,13 @@ def __init__(self, self.gguf_writer.add_uint32 ("clip.vision.embedding_length", 1152) self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", 4304) self.gguf_writer.add_uint32 ("clip.vision.projection_dim", hparams["hidden_size"]) - self.gguf_writer.add_uint32 ("clip.vision.block_count", 12) - self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", 12) + self.gguf_writer.add_uint32 ("clip.vision.block_count", N_LAYERS) + self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", HEAD_COUNT) self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", 1e-6) self.gguf_writer.add_array ("clip.vision.image_mean", SIGLIP_MODEL["image_mean"]) self.gguf_writer.add_array ("clip.vision.image_std", SIGLIP_MODEL["image_std"]) self.gguf_writer.add_bool ("clip.use_gelu", False) + self.gguf_writer.add_array ("clip.vision.feature_layer", [N_LAYERS]) # load tensors for name, data_torch in self.get_tensors(dir_model): From 605a10f5dd1dc555d040aed5265148c20fffc9e9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 9 Mar 2025 12:01:38 +0100 Subject: [PATCH 3/3] it works! 
--- examples/llava/clip.cpp | 10 +++++++--- examples/llava/phi4mm_convert_encoder_to_gguf.py | 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index b81be1910cee9..2051c1064d8da 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -881,11 +881,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // FIXME: phi-4, wrap this into an "if" condition int n_tokens = embeddings->ne[1]; int n_tokens_sqrt = sqrtf(n_tokens); + int downscale_factor = 2; printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); - embeddings = ggml_reshape_4d(ctx0, embeddings, n_tokens_sqrt, n_tokens_sqrt, hidden_size, batch_size); - embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); - embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, n_tokens / 4, batch_size); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_tokens_sqrt, n_tokens_sqrt, hidden_size); + // downscale n_tokens_sqrt*n_tokens_sqrt to n_tokens_sqrt/2*n_tokens_sqrt/2 + embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, downscale_factor, downscale_factor, downscale_factor, downscale_factor, 0, 0); + // flatten first two dimensions + embeddings = ggml_reshape_2d(ctx0, embeddings, n_tokens_sqrt/2*n_tokens_sqrt/2, hidden_size); + embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); // mlp embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); diff --git a/examples/llava/phi4mm_convert_encoder_to_gguf.py b/examples/llava/phi4mm_convert_encoder_to_gguf.py index b26ccdba5520e..ea8550b1f2d45 100644 --- a/examples/llava/phi4mm_convert_encoder_to_gguf.py +++ b/examples/llava/phi4mm_convert_encoder_to_gguf.py @@ -43,6 +43,7 @@ } } N_LAYERS = 27 +FEATURE_LAYER = -2 HEAD_COUNT = 16 @@ -159,7 +160,7 @@ def __init__(self, self.gguf_writer.add_array ("clip.vision.image_mean", SIGLIP_MODEL["image_mean"]) self.gguf_writer.add_array ("clip.vision.image_std", SIGLIP_MODEL["image_std"]) self.gguf_writer.add_bool ("clip.use_gelu", False) - self.gguf_writer.add_array ("clip.vision.feature_layer", [N_LAYERS]) + self.gguf_writer.add_array ("clip.vision.feature_layer", [N_LAYERS + FEATURE_LAYER]) # load tensors for name, data_torch in self.get_tensors(dir_model):
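(Editor's note: taken together, the final clip.cpp hunk pools the encoder output 2x2 and runs it through the two-layer GELU projector. A short PyTorch re-expression of that computation, with shapes written as ggml's ne[] order reversed; mm_0 and mm_2 stand in for the img_projection weights, and the 3072 output width is assumed from the Phi-4-mini hidden size, so treat the exact dimensions as placeholders rather than values taken from the patch:

    import torch
    import torch.nn.functional as F

    n_tokens, hidden = 1024, 1152          # 32x32 SigLIP patch embeddings
    side = int(n_tokens ** 0.5)            # 32
    x = torch.randn(n_tokens, hidden)      # encoder output, one row per patch

    # transpose + reshape so the token grid becomes a 32x32 map per channel
    x = x.T.reshape(hidden, side, side)    # [1152, 32, 32]
    x = F.avg_pool2d(x, kernel_size=2)     # [1152, 16, 16], 1024 -> 256 tokens
    x = x.reshape(hidden, -1).T            # [256, 1152]

    # mm_0 / mm_2: two-layer GELU MLP projector (dims assumed, not from the patch)
    mm_0 = torch.nn.Linear(hidden, 3072)
    mm_2 = torch.nn.Linear(3072, 3072)
    x = mm_2(F.gelu(mm_0(x)))              # [256, 3072], passed to llama_decode as embd
    print(x.shape)

The same commit also moves clip.vision.feature_layer from the last block (27) to block 25 via FEATURE_LAYER = -2, i.e. the hidden state two blocks before the end, presumably to match the reference HF implementation; the patch itself does not state the reason.)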