diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 91af508a2fb28..a015ecee08328 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -432,6 +432,9 @@ def load_hparams(dir_model: Path): if "llm_config" in config: # rename for InternVL config["text_config"] = config["llm_config"] + if "thinker_config" in config: + # rename for Qwen2.5-Omni + config["text_config"] = config["thinker_config"]["text_config"] return config @classmethod @@ -1121,18 +1124,21 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + has_vision_encoder: bool = True # by default has_audio_encoder: bool = False + # for models having multiple encoders, we need to separate their hparams + hparams_vision: dict[str, Any] | None = None + hparams_audio: dict[str, Any] | None = None + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.model_arch != gguf.MODEL_ARCH.MMPROJ: raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ") - if self.has_vision_encoder and self.has_audio_encoder: - raise NotImplementedError("both vision + audio not supported yet") - # get n_embd of the text model if "text_config" not in self.hparams: self.hparams["text_config"] = {} @@ -1143,22 +1149,32 @@ def __init__(self, *args, **kwargs): assert self.n_embd_text > 0, "n_embd not found in hparams" # move vision config to the top level, while preserving the original hparams in global_config - self.global_config = self.hparams + import copy + self.global_config = copy.deepcopy(self.hparams) + self.hparams_vision = self.get_vision_config() + self.hparams_audio = self.get_audio_config() - if "vision_config" in self.hparams: - self.hparams = self.hparams["vision_config"] - elif "audio_config" in self.hparams: - self.hparams = self.hparams["audio_config"] - else: + if self.hparams_vision is None and self.hparams_audio is None: raise ValueError("vision_config / audio_config not found in hparams") - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]) + # for compat with vision-only models + self.hparams = self.hparams_vision or self.hparams_audio or self.hparams + + # TODO @ngxson : this is a hack to support both vision and audio encoders + have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config.get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config.get("audio_config") + def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) @@ -1170,26 +1186,26 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_projection_dim(self.n_embd_text) # vision config - self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"])) - self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_vision_block_count(self.block_count) - self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"])) + self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) # preprocessor config self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) - elif self.has_audio_encoder: + if self.has_audio_encoder: self.gguf_writer.add_clip_has_audio_encoder(True) self.gguf_writer.add_audio_projection_dim(self.n_embd_text) # audio config - self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"])) - self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_audio_block_count(self.block_count) - self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"])) + self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"])) + self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"])) + self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys)) + self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"])) else: raise ValueError("MmprojModel must have either vision or audio encoder") @@ -1197,6 +1213,22 @@ def set_gguf_parameters(self): def write_vocab(self): raise ValueError("MmprojModel does not support vocab writing") + def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_vision is not None + return self._find_param(self.hparams_vision, keys, optional) + + def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_audio is not None + return self._find_param(self.hparams_audio, keys, optional) + + def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in obj), None) + if key is not None: + return obj[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + @ModelBase.register("GPTNeoXForCausalLM") class GPTNeoXModel(TextModel): @@ -2674,7 +2706,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +@ModelBase.register( + "Qwen2VLModel", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5OmniModel", +) class Qwen2VLModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2VL @@ -2692,8 +2729,11 @@ def set_vocab(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.startswith("visual."): - # skip visual tensors + if name.startswith("thinker."): + name = name.replace("thinker.", "") + if name.startswith("visual") or name.startswith("audio") or \ + name.startswith("talker") or name.startswith("token2wav"): + # skip multimodal tensors return [] return [(self.map_tensor_name(name), data_torch)] @@ -2702,21 +2742,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class Qwen2VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.hparams["image_size"] = self.hparams.get("image_size", 560) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) # rename config.json values - self.hparams["num_attention_heads"] = self.hparams.get("num_heads") - self.hparams["num_hidden_layers"] = self.hparams.get("depth") - if "embed_dim" in self.hparams: # qwen2vl - self.hparams["intermediate_size"] = self.hparams.get("hidden_size") - self.hparams["hidden_size"] = self.hparams.get("embed_dim") + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + if "embed_dim" in self.hparams_vision: # qwen2vl + self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size") + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") def set_gguf_parameters(self): super().set_gguf_parameters() - hparams = self.hparams - if self.global_config['model_type'] == 'qwen2_vl': + assert self.hparams_vision is not None + hparams = self.hparams_vision + model_type = self.global_config['model_type'] + if model_type == 'qwen2_vl': self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL) - elif self.global_config['model_type'] == 'qwen2_5_vl': - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) + elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni': + if model_type == 'qwen2_5_omni': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + else: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) self.gguf_writer.add_vision_use_silu(True) # find n_wa_pattern (window attention pattern) fullatt_block_indexes = hparams.get("fullatt_block_indexes") @@ -2774,6 +2820,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Qwen2_5OmniModel") +class Qwen25OmniModel(Qwen2VLVisionModel): + has_vision_encoder = True + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_audio is not None + self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"] + self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"] + self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_audio is not None + self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # SinusoidsPositionEmbedding + assert self.hparams_audio is not None + max_timescale = 10000 + length = 1500 + channels = self.hparams_audio["hidden_size"] + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32) + yield ("audio_tower.embed_positions.weight", pos_embd) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("thinker."): + name = name.replace("thinker.", "") + + if name.startswith("audio_tower"): + # process audio tensors + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + if "audio_bos_eos_token" in name: + # this tensor is left unused in transformers code + # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809 + return [] + return [(self.map_tensor_name(name), data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("InternVisionModel") class InternVisionModel(MmprojModel): def set_gguf_parameters(self): diff --git a/docs/multimodal.md b/docs/multimodal.md index 3a0994a279ae8..e849c2a0b8ba1 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -98,3 +98,12 @@ NOTE: some models may require large context window, for example: `-c 8192` # note: no pre-quantized GGUF this model, as they have very poor result # ref: https://github.com/ggml-org/llama.cpp/pull/13760 ``` + +**Mixed modalities**: + +```sh +# Qwen2.5 Omni +# Capabilities: audio input, vision input +(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF +(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF +``` diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c6255d6867a15..31163effad8f2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2260,6 +2260,7 @@ class VisionProjectorType: ULTRAVOX = "ultravox" INTERNVL = "internvl" QWEN2A = "qwen2a" # audio + QWEN25O = "qwen2.5o" # omni # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4a0615b656812..000ffd00615b5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1125,6 +1125,7 @@ class TensorNameMap: MODEL_TENSOR.A_POST_NORM: ( "audio_tower.layer_norm", # ultravox + "audio_tower.ln_post", # qwen2omni ), MODEL_TENSOR.A_ENC_ATTN_Q: ( @@ -1161,12 +1162,16 @@ class TensorNameMap: "audio_tower.layers.{bid}.fc2", # ultravox ), + # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors + # this prefix is added in the conversion code in modify_tensors() + MODEL_TENSOR.A_MMPROJ: ( "audio.multi_modal_projector.linear_{bid}", # ultravox ), MODEL_TENSOR.A_MMPROJ_FC: ( "audio.multi_modal_projector.linear", # qwen2audio + "audio_tower.proj", # qwen2omni ), MODEL_TENSOR.A_MM_NORM_PRE: ( diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 27ce8c43f678c..62c936ed00f77 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -130,6 +130,7 @@ enum projector_type { PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_UNKNOWN, }; @@ -148,6 +149,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 6205dad5ae262..6ae2c2ce46fd2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -166,9 +166,6 @@ enum patch_merge_type { }; struct clip_hparams { - bool has_vision = false; - bool has_audio = false; - int32_t image_size; int32_t patch_size; int32_t n_embd; @@ -178,9 +175,13 @@ struct clip_hparams { int32_t n_layer; int32_t proj_scale_factor = 0; // idefics3 + float image_mean[3]; + float image_std[3]; + // for models using dynamic image size, we need to have a smaller image size to warmup // otherwise, user will get OOM everytime they load the model int32_t warmup_image_size = 0; + int32_t warmup_audio_size = 3000; ffn_op_type ffn_op = FFN_GELU; @@ -199,6 +200,10 @@ struct clip_hparams { // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox + + // legacy + bool has_llava_projector = false; + int minicpmv_version = 0; }; struct clip_layer { @@ -236,8 +241,10 @@ struct clip_layer { ggml_tensor * ls_2_w = nullptr; }; -struct clip_vision_model { - struct clip_hparams hparams; +struct clip_model { + clip_modality modality = CLIP_MODALITY_VISION; + projector_type proj_type = PROJECTOR_TYPE_MLP; + clip_hparams hparams; // embeddings ggml_tensor * class_embedding = nullptr; @@ -353,14 +360,7 @@ struct clip_vision_model { }; struct clip_ctx { - bool has_llava_projector = false; - int minicpmv_version = 0; - - struct clip_vision_model vision_model; - projector_type proj_type = PROJECTOR_TYPE_MLP; - - float image_mean[3]; - float image_std[3]; + clip_model model; gguf_context_ptr ctx_gguf; ggml_context_ptr ctx_data; @@ -414,11 +414,16 @@ struct clip_ctx { ggml_backend_free(backend_cpu); } } + + // this function is added so that we don't change too much of the existing code + projector_type proj_type() const { + return model.proj_type; + } }; struct clip_graph { clip_ctx * ctx; - const clip_vision_model & model; + const clip_model & model; const clip_hparams & hparams; // we only support single image per batch @@ -441,7 +446,7 @@ struct clip_graph { clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : ctx(ctx), - model(ctx->vision_model), + model(ctx->model), hparams(model.hparams), img(img), patch_size(hparams.patch_size), @@ -473,7 +478,7 @@ struct clip_graph { model.position_embeddings, nullptr); - if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) { const int batch_size = 1; GGML_ASSERT(n_patches_x == n_patches_y); const int patches_per_image = n_patches_x; @@ -496,7 +501,7 @@ struct clip_graph { ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), cur); - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 const int scale_factor = model.hparams.proj_scale_factor; @@ -630,7 +635,7 @@ struct clip_graph { const int n_pos = n_patches; const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position - norm_type norm_t = ctx->proj_type == PROJECTOR_TYPE_QWEN25VL + norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL ? NORM_TYPE_RMS // qwen 2.5 vl : NORM_TYPE_NORMAL; // qwen 2 vl @@ -846,11 +851,11 @@ struct clip_graph { const int d_head = 128; int n_head = n_embd/d_head; int num_query = 96; - if (ctx->minicpmv_version == 2) { + if (ctx->model.hparams.minicpmv_version == 2) { num_query = 96; - } else if (ctx->minicpmv_version == 3) { + } else if (ctx->model.hparams.minicpmv_version == 3) { num_query = 64; - } else if (ctx->minicpmv_version == 4) { + } else if (ctx->model.hparams.minicpmv_version == 4) { num_query = 64; } @@ -1067,7 +1072,7 @@ struct clip_graph { int il_last = hparams.n_layer - 1; int deepest_feature_layer = -1; - if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { il_last += 1; } @@ -1201,7 +1206,7 @@ struct clip_graph { } // llava projector (also used by granite) - if (ctx->has_llava_projector) { + if (ctx->model.hparams.has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); @@ -1215,7 +1220,7 @@ struct clip_graph { // print_tensor_info(embeddings, "embeddings"); // llava projector - if (ctx->proj_type == PROJECTOR_TYPE_MLP) { + if (ctx->proj_type() == PROJECTOR_TYPE_MLP) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); @@ -1225,7 +1230,7 @@ struct clip_graph { embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1246,7 +1251,7 @@ struct clip_graph { embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), model.mm_4_b); } - else if (ctx->proj_type == PROJECTOR_TYPE_LDP) { + else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) { // MobileVLM projector int n_patch = 24; ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); @@ -1356,7 +1361,7 @@ struct clip_graph { } embeddings = block_1; } - else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) + else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2) { int n_patch = 24; ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); @@ -1386,7 +1391,7 @@ struct clip_graph { } // glm projector - else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { size_t gridsz = (size_t)sqrt(embeddings->ne[1]); embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); @@ -1473,7 +1478,7 @@ struct clip_graph { cb(cur, "after_transformer", -1); - if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) { + if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { // StackAudioFrames // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py { @@ -1518,7 +1523,7 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); } - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { // projector cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = ggml_add(ctx0, cur, model.mm_fc_b); @@ -1668,7 +1673,7 @@ struct clip_graph { } // TODO @ngxson : find a way to move this outside - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { ggml_tensor * cur = inpL; cur = ggml_transpose(ctx0, cur); cur = ggml_cont(ctx0, cur); @@ -1947,7 +1952,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_cgraph * res; - switch (ctx->proj_type) { + switch (ctx->proj_type()) { case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: { @@ -1991,13 +1996,15 @@ struct clip_model_loader { ggml_context_ptr ctx_meta; gguf_context_ptr ctx_gguf; - clip_ctx & ctx_clip; std::string fname; size_t model_size = 0; // in bytes - // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model - clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) { + bool has_vision = false; + bool has_audio = false; + + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model + clip_model_loader(const char * fname) : fname(fname) { struct ggml_context * meta = nullptr; struct gguf_init_params params = { @@ -2029,6 +2036,19 @@ struct clip_model_loader { LOG_INF("\n"); } + // modalities + { + get_bool(KEY_HAS_VISION_ENC, has_vision, false); + get_bool(KEY_HAS_AUDIO_ENC, has_audio, false); + + if (has_vision) { + LOG_INF("%s: has vision encoder\n", __func__); + } + if (has_audio) { + LOG_INF("%s: has audio encoder\n", __func__); + } + } + // tensors { for (int i = 0; i < n_tensors; ++i) { @@ -2044,28 +2064,44 @@ struct clip_model_loader { } } - void load_hparams() { - auto & hparams = ctx_clip.vision_model.hparams; + void load_hparams(clip_model & model, clip_modality modality) { + auto & hparams = model.hparams; std::string log_ffn_op; // for logging + // sanity check + if (modality == CLIP_MODALITY_VISION) { + GGML_ASSERT(has_vision); + } else if (modality == CLIP_MODALITY_AUDIO) { + GGML_ASSERT(has_audio); + } + model.modality = modality; + + // projector type std::string proj_type; { get_string(KEY_PROJ_TYPE, proj_type, false); if (!proj_type.empty()) { - ctx_clip.proj_type = clip_projector_type_from_string(proj_type); + model.proj_type = clip_projector_type_from_string(proj_type); } - if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) { + if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) { throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str())); } + + // correct arch for multimodal models + if (model.proj_type == PROJECTOR_TYPE_QWEN25O) { + model.proj_type = modality == CLIP_MODALITY_VISION + ? PROJECTOR_TYPE_QWEN25VL + : PROJECTOR_TYPE_QWEN2A; + } } + const bool is_vision = model.modality == CLIP_MODALITY_VISION; + const bool is_audio = model.modality == CLIP_MODALITY_AUDIO; + // other hparams { - get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false); - get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false); - - const char * prefix = hparams.has_vision ? "vision" : "audio"; + const char * prefix = is_vision ? "vision" : "audio"; get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd); get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head); get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff); @@ -2073,27 +2109,27 @@ struct clip_model_loader { get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim); get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps); - if (hparams.has_vision) { + if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); - get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy + get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy - } else if (hparams.has_audio) { + } else if (is_audio) { get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins); } else { - throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__)); + GGML_ASSERT(false && "unknown modality"); } // default warmup value hparams.warmup_image_size = hparams.image_size; - ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP - || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM - || ctx_clip.proj_type == PROJECTOR_TYPE_LDP - || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2; + hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP + || model.proj_type == PROJECTOR_TYPE_MLP_NORM + || model.proj_type == PROJECTOR_TYPE_LDP + || model.proj_type == PROJECTOR_TYPE_LDPV2; { bool use_gelu = false; @@ -2123,7 +2159,7 @@ struct clip_model_loader { } } - if (hparams.has_vision) { + if (is_vision) { int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); @@ -2131,8 +2167,8 @@ struct clip_model_loader { const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); for (int i = 0; i < 3; ++i) { - ctx_clip.image_mean[i] = mean_data[i]; - ctx_clip.image_std[i] = std_data[i]; + hparams.image_mean[i] = mean_data[i]; + hparams.image_std[i] = std_data[i]; } } @@ -2149,11 +2185,11 @@ struct clip_model_loader { } // model-specific params - switch (ctx_clip.proj_type) { + switch (model.proj_type) { case PROJECTOR_TYPE_MINICPMV: { - if (ctx_clip.minicpmv_version == 0) { - ctx_clip.minicpmv_version = 2; // default to 2 if not set + if (hparams.minicpmv_version == 0) { + hparams.minicpmv_version = 2; // default to 2 if not set } } break; case PROJECTOR_TYPE_IDEFICS3: @@ -2212,7 +2248,7 @@ struct clip_model_loader { case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: { - bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX; + bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); if (hparams.n_mel_bins != 128) { throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); @@ -2225,23 +2261,22 @@ struct clip_model_loader { } LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str()); - LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision); - LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio); LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd); LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head); LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff); LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer); LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str()); LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim); - LOG_INF("\n"); - if (hparams.has_vision) { + if (is_vision) { + LOG_INF("\n--- vision hparams ---\n"); LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size); LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size); - LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector); - LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); + LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); + LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor); LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); - } else if (hparams.has_audio) { + } else if (is_audio) { + LOG_INF("\n--- audio hparams ---\n"); LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor); } @@ -2251,13 +2286,14 @@ struct clip_model_loader { } } - void load_tensors() { - auto & hparams = ctx_clip.vision_model.hparams; + void load_tensors(clip_ctx & ctx_clip) { + auto & model = ctx_clip.model; + auto & hparams = model.hparams; std::map tensor_offset; std::vector tensors_to_load; // TODO @ngxson : support both audio and video in the future - const char * prefix = hparams.has_audio ? "a" : "v"; + const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v"; // get offsets for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) { @@ -2292,26 +2328,24 @@ struct clip_model_loader { return cur; }; - auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model" - - vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false); + model.class_embedding = get_tensor(TN_CLASS_EMBD, false); - vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false); - vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false); + model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false); + model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false); - vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false); - vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false); + model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false); + model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false); - vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); - vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); - vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); + model.patch_bias = get_tensor(TN_PATCH_BIAS, false); + model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); + model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false); + model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false); // layers - vision_model.layers.resize(hparams.n_layer); + model.layers.resize(hparams.n_layer); for (int il = 0; il < hparams.n_layer; ++il) { - auto & layer = vision_model.layers[il]; + auto & layer = model.layers[il]; layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight")); layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight")); layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight")); @@ -2352,166 +2386,166 @@ struct clip_model_loader { } } - switch (ctx_clip.proj_type) { + switch (model.proj_type) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: { // LLaVA projection - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); // Yi-type llava - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); // missing in Yi-type llava - vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); - vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); // Yi-type llava - vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); - vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); - vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); - vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); - if (vision_model.mm_3_w) { + model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); + model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); + model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); + model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); + if (model.mm_3_w) { // TODO: this is a hack to support Yi-type llava - ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM; + model.proj_type = PROJECTOR_TYPE_MLP_NORM; } - vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); + model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); } break; case PROJECTOR_TYPE_LDP: { // MobileVLM projection - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); - vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); - vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); - vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); - vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); - vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); - vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); - vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); - vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); - vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); - vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); - vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); - vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); - vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); - vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); - vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); - vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); - vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); - vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); - vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); - vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); + model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); + model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); + model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); + model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); + model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); + model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); + model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); + model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); + model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); + model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); + model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); + model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); + model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); + model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); + model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); + model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); + model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); + model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); + model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); } break; case PROJECTOR_TYPE_LDPV2: { // MobilVLM_V2 projection - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); - vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); - vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); - vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); + model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); + model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); + model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); } break; case PROJECTOR_TYPE_MINICPMV: { - // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); - vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); - vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); - vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); - vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); - vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); - vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); - vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); - vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); - vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); - vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); - vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); - vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); - vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); - vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); - vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); + // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); + model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); + model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); + model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); + model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); + model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); + model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); + model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); + model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); + model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); + model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); + model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); + model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); + model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); + model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); + model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); + model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); + model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); } break; case PROJECTOR_TYPE_GLM_EDGE: { - vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); - vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias")); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); - vision_model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); - vision_model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); + model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); + model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); + model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight")); + model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight")); + model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias")); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); + model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); + model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); + model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case PROJECTOR_TYPE_GEMMA3: { - vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); - vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; case PROJECTOR_TYPE_IDEFICS3: { - vision_model.projection = get_tensor(TN_MM_PROJECTOR); + model.projection = get_tensor(TN_MM_PROJECTOR); } break; case PROJECTOR_TYPE_PIXTRAL: { - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); - vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); // [IMG_BREAK] token embedding - vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); + model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); // for mistral small 3.1 - vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); - vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; case PROJECTOR_TYPE_ULTRAVOX: { - vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); - vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); - vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); - vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); - vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); - vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); - vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight")); + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); + model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight")); } break; case PROJECTOR_TYPE_QWEN2A: { - vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); - vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); - vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); - vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); - vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); - vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); + model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); } break; case PROJECTOR_TYPE_INTERNVL: { - vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; case PROJECTOR_TYPE_LLAMA4: { - vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); } break; default: GGML_ASSERT(false && "unknown projector type"); @@ -2554,21 +2588,20 @@ struct clip_model_loader { } } - void alloc_compute_meta() { - const auto & hparams = ctx_clip.vision_model.hparams; + void alloc_compute_meta(clip_ctx & ctx_clip) { + const auto & hparams = ctx_clip.model.hparams; ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); // create a fake batch clip_image_f32_batch batch; clip_image_f32_ptr img(clip_image_f32_init()); - if (hparams.has_vision) { + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { img->nx = hparams.warmup_image_size; img->ny = hparams.warmup_image_size; } else { - img->nx = 1024; // TODO @ngxson : use a better default + img->nx = hparams.warmup_audio_size; img->ny = hparams.n_mel_bins; } - img->buf.resize(img->nx * img->ny * 3); batch.entries.push_back(std::move(img)); ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); @@ -2646,23 +2679,40 @@ struct clip_model_loader { } }; -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { +struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) { g_logger_state.verbosity_thold = ctx_params.verbosity; - clip_ctx * ctx_clip = nullptr; + clip_ctx * ctx_vision = nullptr; + clip_ctx * ctx_audio = nullptr; try { - ctx_clip = new clip_ctx(ctx_params); - clip_model_loader loader(fname, *ctx_clip); - loader.load_hparams(); - loader.load_tensors(); - loader.alloc_compute_meta(); + clip_model_loader loader(fname); + + if (loader.has_vision) { + ctx_vision = new clip_ctx(ctx_params); + loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); + loader.load_tensors(*ctx_vision); + loader.alloc_compute_meta(*ctx_vision); + } + + if (loader.has_audio) { + ctx_audio = new clip_ctx(ctx_params); + loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); + loader.load_tensors(*ctx_audio); + loader.alloc_compute_meta(*ctx_audio); + } + } catch (const std::exception & e) { LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what()); - delete ctx_clip; - return nullptr; + if (ctx_vision) { + delete ctx_vision; + } + if (ctx_audio) { + delete ctx_audio; + } + return {nullptr, nullptr}; } - return ctx_clip; + return {ctx_vision, ctx_audio}; } struct clip_image_size * clip_image_size_init() { @@ -3023,12 +3073,12 @@ struct llava_uhd { const float ratio = (float)original_width * original_height / (slice_size * slice_size); const int multiple = fmin(ceil(ratio), max_slice_nums); const bool has_slices = (multiple > 1); - const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty(); + const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty(); if (has_pinpoints) { // has pinpoints, use them to calculate the grid size (e.g. llava-1.6) auto refine_size = llava_uhd::select_best_resolution( - ctx->vision_model.hparams.image_grid_pinpoints, + ctx->model.hparams.image_grid_pinpoints, original_size); res.overview_size = clip_image_size{slice_size, slice_size}; res.refined_size = refine_size; @@ -3250,7 +3300,7 @@ struct llava_uhd { bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { clip_image_size original_size{img->nx, img->ny}; bool pad_to_square = true; - auto & params = ctx->vision_model.hparams; + auto & params = ctx->model.hparams; // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { pad_to_square = false; @@ -3263,7 +3313,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str for (size_t i = 0; i < imgs.size(); ++i) { // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3271,7 +3321,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -3279,42 +3329,42 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_f32_ptr img_f32(clip_image_f32_init()); // clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; } - else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE - || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 - || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution + else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 + || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 + || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; int sz = params.image_size; image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz}); clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); clip_image_f32_ptr img_f32(clip_image_f32_init()); - normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) { GGML_ASSERT(!params.image_grid_pinpoints.empty()); auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); std::vector imgs = llava_uhd::slice_image(img, inst); for (size_t i = 0; i < imgs.size(); ++i) { clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3344,7 +3394,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); return true; @@ -3356,7 +3406,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str for (size_t i = 0; i < imgs.size(); ++i) { // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); } @@ -3368,7 +3418,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { - return ctx->vision_model.image_newline; + return ctx->model.image_newline; } void clip_free(clip_ctx * ctx) { @@ -3380,8 +3430,8 @@ void clip_free(clip_ctx * ctx) { // deprecated size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - const int32_t nx = ctx->vision_model.hparams.image_size; - const int32_t ny = ctx->vision_model.hparams.image_size; + const int32_t nx = ctx->model.hparams.image_size; + const int32_t ny = ctx->model.hparams.image_size; return clip_embd_nbytes_by_img(ctx, nx, ny); } @@ -3393,105 +3443,135 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h } int32_t clip_get_image_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_size; + return ctx->model.hparams.image_size; } int32_t clip_get_patch_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.patch_size; + return ctx->model.hparams.patch_size; } int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.n_embd; + return ctx->model.hparams.n_embd; } const char * clip_patch_merge_type(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; + return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { - return &ctx->vision_model.hparams.image_grid_pinpoints.front(); + if (ctx->model.hparams.image_grid_pinpoints.size()) { + return &ctx->model.hparams.image_grid_pinpoints.front(); } return nullptr; } size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints.size(); + return ctx->model.hparams.image_grid_pinpoints.size(); } int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; + const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0); } return n_total; } int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; - if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + const auto & params = ctx->model.hparams; + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0); } return 1; } int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; + const auto & params = ctx->model.hparams; - int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - int scale_factor = ctx->vision_model.hparams.proj_scale_factor; + // only for models using fixed size square images + int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP - || ctx->proj_type == PROJECTOR_TYPE_LDPV2 - || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - n_patches /= 4; - if (ctx->vision_model.mm_glm_tok_boi) { - n_patches += 2; // for BOI and EOI token embeddings - } - } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { - if (ctx->minicpmv_version == 2) { - n_patches = 96; - } - else if (ctx->minicpmv_version == 3) { - n_patches = 64; - } - else if (ctx->minicpmv_version == 4) { - n_patches = 64; - } - else { - GGML_ABORT("Unknown minicpmv version"); - } - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { - int patch_size = params.patch_size * 2; - int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); - int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); - n_patches = x_patch * y_patch; - } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - int n_per_side = params.image_size / params.patch_size; - int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; - n_patches = n_per_side_2d_pool * n_per_side_2d_pool; - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) { - // both W and H are divided by proj_scale_factor - n_patches /= (params.proj_scale_factor * params.proj_scale_factor); - } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { - int n_merge = params.spatial_merge_size; - int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); - int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); - n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row - } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { - n_patches /= (scale_factor * scale_factor); - } else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) { - const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor; - const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor); - n_patches = n_len / proj_stack_factor / 2; - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) { - // divide by 2 because of whisper - // another divide by 2 because of nn.AvgPool1d(2, stride=2) - n_patches = img->nx / 4; - } - - return n_patches; + projector_type proj = ctx->proj_type(); + + switch (proj) { + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + { + // do nothing + } break; + case PROJECTOR_TYPE_LDP: + case PROJECTOR_TYPE_LDPV2: + case PROJECTOR_TYPE_GLM_EDGE: + { + n_patches_sq /= 4; + if (ctx->model.mm_glm_tok_boi) { + n_patches_sq += 2; // for BOI and EOI token embeddings + } + } break; + case PROJECTOR_TYPE_MINICPMV: + { + if (params.minicpmv_version == 2) { + n_patches_sq = 96; + } else if (params.minicpmv_version == 3) { + n_patches_sq = 64; + } else if (params.minicpmv_version == 4) { + n_patches_sq = 64; + } else { + GGML_ABORT("Unknown minicpmv version"); + } + } break; + case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: + { + // dynamic size + int patch_size = params.patch_size * 2; + int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); + int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); + n_patches_sq = x_patch * y_patch; + } break; + case PROJECTOR_TYPE_GEMMA3: + { + int n_per_side = params.image_size / params.patch_size; + int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; + n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool; + } break; + case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: + { + // both W and H are divided by proj_scale_factor + n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor); + } break; + case PROJECTOR_TYPE_PIXTRAL: + { + // dynamic size + int n_merge = params.spatial_merge_size; + int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); + int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); + n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } break; + case PROJECTOR_TYPE_LLAMA4: + { + int scale_factor = ctx->model.hparams.proj_scale_factor; + n_patches_sq /= (scale_factor * scale_factor); + } break; + case PROJECTOR_TYPE_ULTRAVOX: + { + const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; + const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor); + n_patches_sq = n_len / proj_stack_factor / 2; + } break; + case PROJECTOR_TYPE_QWEN2A: + { + // divide by 2 because of whisper + // another divide by 2 because of nn.AvgPool1d(2, stride=2) + n_patches_sq = img->nx / 4; + } break; + default: + GGML_ABORT("unsupported projector type"); + } + + return n_patches_sq; } static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { @@ -3606,7 +3686,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); // set inputs - const auto & model = ctx->vision_model; + const auto & model = ctx->model; const auto & hparams = model.hparams; const int image_size_width = imgs.entries[0]->nx; @@ -3696,7 +3776,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // set input per projector - switch (ctx->proj_type) { + switch (ctx->model.proj_type) { case PROJECTOR_TYPE_MINICPMV: { // inspired from siglip: @@ -3961,80 +4041,83 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } int clip_n_mmproj_embd(const struct clip_ctx * ctx) { - switch (ctx->proj_type) { + const auto & hparams = ctx->model.hparams; + switch (ctx->model.proj_type) { case PROJECTOR_TYPE_LDP: - return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0]; + return ctx->model.mm_model_block_1_block_2_1_b->ne[0]; case PROJECTOR_TYPE_LDPV2: - return ctx->vision_model.mm_model_peg_0_b->ne[0]; + return ctx->model.mm_model_peg_0_b->ne[0]; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: - return ctx->vision_model.mm_2_w->ne[1]; + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: - return ctx->vision_model.mm_3_b->ne[0]; + return ctx->model.mm_3_b->ne[0]; case PROJECTOR_TYPE_MINICPMV: - if (ctx->minicpmv_version == 2) { + if (hparams.minicpmv_version == 2) { return 4096; - } else if (ctx->minicpmv_version == 3) { + } else if (hparams.minicpmv_version == 3) { return 3584; - } else if (ctx->minicpmv_version == 4) { + } else if (hparams.minicpmv_version == 4) { return 3584; } GGML_ABORT("Unknown minicpmv version"); case PROJECTOR_TYPE_GLM_EDGE: - return ctx->vision_model.mm_model_mlp_3_w->ne[1]; + return ctx->model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: - return ctx->vision_model.mm_1_b->ne[0]; + return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: - return ctx->vision_model.mm_input_proj_w->ne[0]; + return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: - return ctx->vision_model.projection->ne[1]; + return ctx->model.projection->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: - return ctx->vision_model.mm_2_w->ne[1]; + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: - return ctx->vision_model.mm_3_w->ne[1]; + return ctx->model.mm_3_w->ne[1]; case PROJECTOR_TYPE_LLAMA4: - return ctx->vision_model.mm_model_proj->ne[1]; + return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: - return ctx->vision_model.mm_fc_w->ne[1]; + return ctx->model.mm_fc_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } } int clip_is_minicpmv(const struct clip_ctx * ctx) { - if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { - return ctx->minicpmv_version; + if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) { + return ctx->model.hparams.minicpmv_version; } return 0; } bool clip_is_glm(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE; + return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE; } bool clip_is_qwen2vl(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; + return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL + || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL; } bool clip_is_llava(const struct clip_ctx * ctx) { - return ctx->has_llava_projector; + return ctx->model.hparams.has_llava_projector; } bool clip_is_gemma3(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_GEMMA3; + return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } bool clip_has_vision_encoder(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.has_vision; + return ctx->model.modality == CLIP_MODALITY_VISION; } bool clip_has_audio_encoder(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.has_audio; + return ctx->model.modality == CLIP_MODALITY_AUDIO; } bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A; + return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX + || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A; } bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { @@ -4055,7 +4138,7 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, // projector_type clip_get_projector_type(const struct clip_ctx * ctx) { - return ctx->proj_type; + return ctx->proj_type(); } void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 5abfcd1a3c418..cb2eb261fe2e8 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -17,12 +17,22 @@ struct clip_image_f32; struct clip_image_u8_batch; struct clip_image_f32_batch; +enum clip_modality { + CLIP_MODALITY_VISION, + CLIP_MODALITY_AUDIO, +}; + struct clip_context_params { bool use_gpu; enum ggml_log_level verbosity; }; -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params); +struct clip_init_result { + struct clip_ctx * ctx_v; // vision context + struct clip_ctx * ctx_a; // audio context +}; + +struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params); void clip_free(struct clip_ctx * ctx); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 0f8bb0cdc42dc..a70f11ca9d718 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -284,7 +284,9 @@ int main(int argc, char ** argv) { if (is_single_turn) { g_is_generating = true; if (params.prompt.find(mtmd_default_marker()) == std::string::npos) { - params.prompt += mtmd_default_marker(); + for (size_t i = 0; i < params.image.size(); i++) { + params.prompt += mtmd_default_marker(); + } } common_chat_msg msg; msg.role = "user"; diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index b79094c0a48b6..e6c926080cde3 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -66,7 +66,8 @@ struct decode_embd_batch { } } - void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) { + // M-RoPE for image + void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) { GGML_ASSERT(n_pos_per_embd == 4); seq_id_0[0] = seq_id; for (int y = 0; y < ny; y++) { @@ -85,6 +86,23 @@ struct decode_embd_batch { } } + // M-RoPE for audio + void set_position_mrope_1d(llama_pos pos_0, llama_seq_id seq_id) { + GGML_ASSERT(n_pos_per_embd == 4); + seq_id_0[0] = seq_id; + for (int i = 0; i < batch.n_tokens; i++) { + pos[i ] = pos_0 + i; + pos[i + batch.n_tokens ] = pos_0 + i; + pos[i + batch.n_tokens * 2] = pos_0 + i; + pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + } + for (int i = 0; i < batch.n_tokens; i++) { + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } + llama_batch get_view(int offset, int n_tokens) { llama_pos * pos_ptr; pos_view.clear(); @@ -146,18 +164,20 @@ int32_t mtmd_helper_decode_image_chunk( decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd); if (mtmd_decode_use_mrope(ctx)) { - const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); - if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) { - LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n"); - return -1; - } - if (!image_tokens) { - LOG_ERR("failed to decode chunk: image tokens are null\n"); - return -1; + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + if (!image_tokens) { + LOG_ERR("failed to decode chunk: image tokens are null\n"); + return -1; + } + const int nx = mtmd_image_tokens_get_nx(image_tokens); + const int ny = mtmd_image_tokens_get_ny(image_tokens); + batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id); + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + batch_embd.set_position_mrope_1d(n_past, seq_id); + } else { + GGML_ABORT("invalid chunk type for M-RoPE"); } - const int nx = mtmd_image_tokens_get_nx(image_tokens); - const int ny = mtmd_image_tokens_get_ny(image_tokens); - batch_embd.set_position_mrope(n_past, nx, ny, seq_id); } else { batch_embd.set_position_normal(n_past, seq_id); } diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index c3be91265f331..52bf71e2c9dc0 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -95,15 +95,21 @@ mtmd_context_params mtmd_context_params_default() { } struct mtmd_context { - struct clip_ctx * ctx_clip; + struct clip_ctx * ctx_v; // vision + struct clip_ctx * ctx_a; // audio const struct llama_model * text_model; std::vector image_embd_v; // image embedding vector bool print_timings; int n_threads; std::string media_marker; - bool has_vision; - bool has_audio; + const int n_embd_text; + + // these are not token, but strings used to mark the beginning and end of image/audio embeddings + std::string img_beg; + std::string img_end; + std::string aud_beg; + std::string aud_end; // for llava-uhd style models, we need special tokens in-between slices // minicpmv calls them "slices", llama 4 calls them "tiles" @@ -132,33 +138,61 @@ struct mtmd_context { text_model (text_model), print_timings(ctx_params.print_timings), n_threads (ctx_params.n_threads), - media_marker (ctx_params.media_marker) + media_marker (ctx_params.media_marker), + n_embd_text (llama_model_n_embd(text_model)) { if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) { throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead"); } + if (media_marker.empty()) { + throw std::runtime_error("media_marker must not be empty"); + } + clip_context_params ctx_clip_params; ctx_clip_params.use_gpu = ctx_params.use_gpu; ctx_clip_params.verbosity = ctx_params.verbosity; - ctx_clip = clip_init(mmproj_fname, ctx_clip_params); - if (!ctx_clip) { + auto res = clip_init(mmproj_fname, ctx_clip_params); + ctx_v = res.ctx_v; + ctx_a = res.ctx_a; + if (!ctx_v && !ctx_a) { throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname)); } - if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) { + // if both vision and audio mmproj are present, we need to validate their n_embd + if (ctx_v && ctx_a) { + int n_embd_v = clip_n_mmproj_embd(ctx_v); + int n_embd_a = clip_n_mmproj_embd(ctx_a); + if (n_embd_v != n_embd_a) { + throw std::runtime_error(string_format( + "mismatch between vision and audio mmproj (n_embd_v = %d, n_embd_a = %d)\n", + n_embd_v, n_embd_a)); + } + } + + // since we already validate n_embd of vision and audio mmproj, + // we can safely assume that they are the same + int n_embd_clip = clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a); + if (n_embd_text != n_embd_clip) { throw std::runtime_error(string_format( "mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n" "hint: you may be using wrong mmproj\n", - llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip))); + n_embd_text, n_embd_clip)); + } + if (ctx_v) { + init_vision(); } + if (ctx_a) { + init_audio(); + } + } - has_vision = clip_has_vision_encoder(ctx_clip); - has_audio = clip_has_audio_encoder(ctx_clip); - use_mrope = clip_is_qwen2vl(ctx_clip); + void init_vision() { + GGML_ASSERT(ctx_v != nullptr); + use_mrope = clip_is_qwen2vl(ctx_v); - projector_type proj = clip_get_projector_type(ctx_clip); - int minicpmv_version = clip_is_minicpmv(ctx_clip); + projector_type proj = clip_get_projector_type(ctx_v); + int minicpmv_version = clip_is_minicpmv(ctx_v); if (minicpmv_version == 2) { // minicpmv 2.5 format: // (overview) (slice) (slice) \n ... @@ -203,24 +237,82 @@ struct mtmd_context { ov_img_first = false; // overview image is last } - if (clip_has_whisper_encoder(ctx_clip)) { + // set boi/eoi + if (proj == PROJECTOR_TYPE_GEMMA3) { + // ... (image embeddings) ... + img_beg = ""; + img_end = ""; + + } else if (proj == PROJECTOR_TYPE_IDEFICS3) { + // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 + img_beg = ""; + img_end = ""; + + } else if (proj == PROJECTOR_TYPE_PIXTRAL) { + // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md + img_end = "[IMG_END]"; + + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) { + // <|vision_start|> ... (image embeddings) ... <|vision_end|> + img_beg = "<|vision_start|>"; + img_end = "<|vision_end|>"; + + } else if (proj == PROJECTOR_TYPE_LLAMA4) { + // (more details in mtmd_context constructor) + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n" + " https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__); + + } else if (proj == PROJECTOR_TYPE_INTERNVL) { + // ... (image embeddings) ... + img_beg = ""; + img_end = ""; + + } + } + + void init_audio() { + GGML_ASSERT(ctx_a != nullptr); + projector_type proj = clip_get_projector_type(ctx_a); + + if (clip_has_whisper_encoder(ctx_a)) { // TODO @ngxson : check if model n_mel is 128 or 80 w_filters = whisper_precalc_filters::get_128_bins(); } - // warning messages - if (proj == PROJECTOR_TYPE_LLAMA4) { - LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n" - " https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__); + LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n" + " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__); + + if (proj == PROJECTOR_TYPE_QWEN2A) { + // <|audio_bos|> ... (embeddings) ... <|audio_eos|> + aud_beg = "<|audio_bos|>"; + aud_end = "<|audio_eos|>"; + } - if (has_audio) { - LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n" - " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__); + } + + // get clip ctx based on chunk type + clip_ctx * get_clip_ctx(const mtmd_input_chunk * chunk) const { + if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + return ctx_v; + } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + return ctx_a; } + GGML_ABORT("unknown chunk type"); + } + + projector_type proj_type_v() const { + return ctx_v ? clip_get_projector_type(ctx_v) : PROJECTOR_TYPE_UNKNOWN; + } + + projector_type proj_type_a() const { + return ctx_a ? clip_get_projector_type(ctx_a) : PROJECTOR_TYPE_UNKNOWN; } ~mtmd_context() { - clip_free(ctx_clip); + clip_free(ctx_a); + clip_free(ctx_v); } private: @@ -267,167 +359,137 @@ void mtmd_free(mtmd_context * ctx) { } } -// copied from common_tokenize -static std::vector mtmd_tokenize_text_internal( - const struct llama_vocab * vocab, - const std::string & text, - bool add_special, - bool parse_special) { - // upper limit for the number of tokens - int n_tokens = text.length() + 2 * add_special; - std::vector result(n_tokens); - n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - return result; -} +struct mtmd_tokenizer { + mtmd_context * ctx; + std::vector bitmaps; -int32_t mtmd_tokenize(mtmd_context * ctx, - mtmd_input_chunks * output, + std::string input_text; + bool add_special; + bool parse_special; + const llama_vocab * vocab; + + mtmd_input_chunks cur; + + mtmd_tokenizer(mtmd_context * ctx, const mtmd_input_text * text, const mtmd_bitmap ** bitmaps, - size_t n_bitmaps) { - auto vocab = llama_model_get_vocab(ctx->text_model); - - std::string prompt_modified(text->text); - std::string marker_modified(ctx->media_marker); - projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); - - // for compatibility, we convert image marker to media marker - string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker); - - // a bit hacky here, but works for now - // for some models, we need to add prefix and suffix to the image embeddings - if (clip_is_gemma3(ctx->ctx_clip)) { - // gemma 3 - // ... (image embeddings) ... - marker_modified = "" + ctx->media_marker + ""; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) { - // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 - marker_modified = "" + ctx->media_marker + ""; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_PIXTRAL) { - // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md - marker_modified = ctx->media_marker + "[IMG_END]"; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) { - // <|vision_start|> ... (image embeddings) ... <|vision_end|> - marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>"; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_LLAMA4) { - // (more details in mtmd_context constructor) - marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>"; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_INTERNVL) { - // ... (image embeddings) ... - marker_modified = "" + ctx->media_marker + ""; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } else if (proj_type == PROJECTOR_TYPE_QWEN2A) { - // <|audio_bos|> ... (embeddings) ... <|audio_eos|> - marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>"; - string_replace_all(prompt_modified, ctx->media_marker, marker_modified); - - } - - // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix - // for glm-edge, BOI and EOI token's embeddings are not present in the text model - - std::vector parts = string_split_str(prompt_modified, ctx->media_marker); - output->entries.clear(); - output->entries.reserve(parts.size()); - - size_t i_bm = 0; - - // utility for adding raw tokens - auto add_text_chunk = [&output](std::vector && tokens) { - mtmd_input_chunk chunk{ - MTMD_INPUT_CHUNK_TYPE_TEXT, - std::move(tokens), - nullptr, // image tokens - nullptr, // audio tokens - }; - output->entries.emplace_back(std::move(chunk)); - }; + size_t n_bitmaps) : ctx(ctx), bitmaps(bitmaps, bitmaps + n_bitmaps) { + add_special = text->add_special; + parse_special = text->parse_special; + input_text = text->text; + vocab = llama_model_get_vocab(ctx->text_model); + + // for compatibility, we convert image marker to media marker + string_replace_all(input_text, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker); + } - // utility for splitting batch of multiple images into chunks of batch having single images - auto split_batch_to_chunk = [&ctx](clip_image_f32_batch && batch_f32, const std::string & id) { - std::vector chunks; + int32_t tokenize(mtmd_input_chunks * output) { + cur.entries.clear(); + std::vector parts = split_text(input_text, ctx->media_marker); + size_t i_bm = 0; // index of the current bitmap + for (auto & part : parts) { + if (part == ctx->media_marker) { + // this is a marker, we should add the next bitmap + if (i_bm >= bitmaps.size()) { + LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n", + __func__, bitmaps.size(), parts.size() - 1); + return 1; + } + const mtmd_bitmap * bitmap = bitmaps[i_bm++]; + int32_t res = add_media(bitmap); + if (res != 0) { + return res; + } + } else { + // this is a text part, we should add it as text + add_text(part, parse_special); + } + } - for (auto & entry : batch_f32.entries) { - mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); - image_tokens->nx = clip_n_output_tokens(ctx->ctx_clip, entry.get()); - image_tokens->ny = 1; - image_tokens->batch_f32.entries.push_back(std::move(entry)); - image_tokens->id = id; + if (add_special && llama_vocab_get_add_bos(vocab)) { + // if first chunk is text, we add BOS token to first text chunk + // otherwise, create a new text chunk with BOS token + if (!cur.entries.empty() && cur.entries[0].type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + // add BOS token to the beginning of first text chunk + cur.entries[0].tokens_text.insert(cur.entries[0].tokens_text.begin(), llama_vocab_bos(vocab)); + } else { + // create a new text chunk with BOS token at the beginning + mtmd_input_chunk bos_chunk{ + MTMD_INPUT_CHUNK_TYPE_TEXT, + {llama_vocab_bos(vocab)}, + nullptr, // image tokens + nullptr, // audio tokens + }; + cur.entries.insert(cur.entries.begin(), std::move(bos_chunk)); + } + } - mtmd_input_chunk chunk{ - MTMD_INPUT_CHUNK_TYPE_IMAGE, - {}, // text tokens - std::move(image_tokens), - nullptr, // audio tokens - }; - chunks.emplace_back(std::move(chunk)); + if (add_special && llama_vocab_get_add_eos(vocab)) { + // if last chunk is text, we add EOS token to it + add_text({llama_vocab_eos(vocab)}); } - return chunks; - }; + if (i_bm != bitmaps.size()) { + LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n", + __func__, bitmaps.size(), parts.size() - 1); + return 1; + } + + *output = std::move(cur); + + return 0; + } + + void add_text(const std::string & txt, bool parse_special) { + LOG_DBG("%s: %s\n", __func__, txt.c_str()); + auto tokens = mtmd_tokenize_text_internal(vocab, txt, /* add_special */ false, parse_special); + add_text(tokens); + } - for (const auto & part : parts) { - // printf("tokenizing part: %s\n", part.c_str()); - bool add_bos = &parts.front() == ∂ - auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special); + void add_text(const std::vector & tokens) { if (tokens.empty()) { - continue; + return; } - mtmd_input_chunk chunk{ - MTMD_INPUT_CHUNK_TYPE_TEXT, - std::move(tokens), - nullptr, // image tokens - nullptr, // audio tokens - }; - output->entries.emplace_back(std::move(chunk)); - - // only add image/audio tokens to middle of 2 parts - // therefore, we skip handling image/audio if this is the last part - if (&parts.back() == &part) { - continue; + // if last entry is also a text chunk, add tokens to it instead of creating new chunk + if (!cur.entries.empty() && cur.entries.back().type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + cur.entries.back().tokens_text.insert( + cur.entries.back().tokens_text.end(), + tokens.begin(), + tokens.end()); + } else { + mtmd_input_chunk chunk{ + MTMD_INPUT_CHUNK_TYPE_TEXT, + tokens, + nullptr, // image tokens + nullptr, // audio tokens + }; + cur.entries.emplace_back(std::move(chunk)); } + } - if (!bitmaps[i_bm]->is_audio) { + int32_t add_media(const mtmd_bitmap * bitmap) { + if (!bitmap->is_audio) { // handle image - if (i_bm >= n_bitmaps) { - LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); - return 1; - } - - if (!ctx->has_vision) { + if (!ctx->ctx_v) { LOG_ERR("%s: error: model does not support vision input\n", __func__); return 2; } + if (!ctx->img_beg.empty()) { + add_text(ctx->img_beg, true); // add image begin token + } + // convert mtmd_bitmap to clip_image_u8 clip_image_u8_ptr img_u8(clip_image_u8_init()); - img_u8->nx = bitmaps[i_bm]->nx; - img_u8->ny = bitmaps[i_bm]->ny; - img_u8->buf.resize(bitmaps[i_bm]->data.size()); - std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3); + img_u8->nx = bitmap->nx; + img_u8->ny = bitmap->ny; + img_u8->buf.resize(bitmap->data.size()); + std::memcpy(img_u8->buf.data(), bitmap->data.data(), img_u8->nx * img_u8->ny * 3); // preprocess image clip_image_f32_batch batch_f32; - bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); + bool ok = clip_image_preprocess(ctx->ctx_v, img_u8.get(), &batch_f32); if (!ok) { LOG_ERR("Unable to preprocess image\n"); return 2; @@ -440,7 +502,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 ) { // split batch into chunks of single images - auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id); + auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmap->id); GGML_ASSERT(chunks.size() > 0); auto ov_chunk = std::move(chunks.front()); @@ -449,11 +511,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx, // add overview image (first) if (ctx->ov_img_first) { if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_ov_img_start}); + add_text({ctx->tok_ov_img_start}); } - output->entries.emplace_back(std::move(ov_chunk)); + cur.entries.emplace_back(std::move(ov_chunk)); if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_ov_img_end}); + add_text({ctx->tok_ov_img_end}); } } @@ -462,53 +524,53 @@ int32_t mtmd_tokenize(mtmd_context * ctx, const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_slices_start}); + add_text({ctx->tok_slices_start}); } for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { const bool is_last_in_row = (x == n_col - 1); if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_sli_img_start}); + add_text({ctx->tok_sli_img_start}); } - output->entries.emplace_back(std::move(chunks[y * n_col + x])); + cur.entries.emplace_back(std::move(chunks[y * n_col + x])); if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_sli_img_end}); + add_text({ctx->tok_sli_img_end}); } if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_sli_img_mid}); + add_text({ctx->tok_sli_img_mid}); } } if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_row_end}); + add_text({ctx->tok_row_end}); } } if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_slices_end}); + add_text({ctx->tok_slices_end}); } } // add overview image (last) if (!ctx->ov_img_first) { if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_ov_img_start}); + add_text({ctx->tok_ov_img_start}); } - output->entries.emplace_back(std::move(ov_chunk)); + cur.entries.emplace_back(std::move(ov_chunk)); if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text_chunk({ctx->tok_ov_img_end}); + add_text({ctx->tok_ov_img_end}); } } } else { size_t n_tokens = 0; for (const auto & entry : batch_f32.entries) { - n_tokens += clip_n_output_tokens(ctx->ctx_clip, entry.get()); + n_tokens += clip_n_output_tokens(ctx->ctx_v, entry.get()); } mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); if (ctx->use_mrope) { // for Qwen2VL, we need this information for M-RoPE decoding positions - image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get()); - image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get()); + image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get()); + image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get()); image_tokens->use_mrope_pos = true; } else { // other models, we only need the total number of tokens @@ -516,7 +578,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, image_tokens->ny = 1; } image_tokens->batch_f32 = std::move(batch_f32); - image_tokens->id = bitmaps[i_bm]->id; // optional + image_tokens->id = bitmap->id; // optional LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx); LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny); @@ -528,35 +590,35 @@ int32_t mtmd_tokenize(mtmd_context * ctx, std::move(image_tokens), nullptr, // audio tokens }; - output->entries.emplace_back(std::move(chunk)); + cur.entries.emplace_back(std::move(chunk)); } - i_bm++; // move to next image - continue; + if (!ctx->img_end.empty()) { + add_text(ctx->img_end, true); // add image end token + } } else { // handle audio - if (i_bm >= n_bitmaps) { - LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); - return 1; - } - - if (!ctx->has_audio) { + if (!ctx->ctx_a) { LOG_ERR("%s: error: model does not support audio input\n", __func__); return 2; } - if (bitmaps[i_bm]->data.size() == 0) { + if (bitmap->data.size() == 0) { LOG_ERR("%s: error: empty audio data\n", __func__); return 2; } + if (!ctx->aud_beg.empty()) { + add_text(ctx->aud_beg, true); // add audio begin token + } + // preprocess audio GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded std::vector mel_spec_chunks; - const float * samples = (const float *)bitmaps[i_bm]->data.data(); - size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float); + const float * samples = (const float *)bitmap->data.data(); + size_t n_samples = bitmap->data.size() / sizeof(float); bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks); if (!ok) { LOG_ERR("Unable to preprocess audio\n"); @@ -570,7 +632,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, mel_f32->nx = mel_spec.n_len; mel_f32->ny = mel_spec.n_mel; mel_f32->buf = std::move(mel_spec.data); - size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get()); + size_t n_tokens = clip_n_output_tokens(ctx->ctx_a, mel_f32.get()); clip_image_f32_batch batch_f32; batch_f32.is_audio = true; @@ -579,7 +641,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens); audio_tokens->n_tokens = n_tokens; audio_tokens->batch_f32 = std::move(batch_f32); - audio_tokens->id = bitmaps[i_bm]->id; // optional + audio_tokens->id = bitmap->id; // optional LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens); @@ -589,15 +651,88 @@ int32_t mtmd_tokenize(mtmd_context * ctx, nullptr, // image tokens std::move(audio_tokens), }; - output->entries.emplace_back(std::move(chunk)); + cur.entries.emplace_back(std::move(chunk)); } - i_bm++; - continue; + if (!ctx->aud_end.empty()) { + add_text(ctx->aud_end, true); // add audio end token + } } + + return 0; } - return 0; + std::vector split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) { + std::vector chunks; + + for (auto & entry : batch_f32.entries) { + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); + image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, entry.get()); + image_tokens->ny = 1; + image_tokens->batch_f32.entries.push_back(std::move(entry)); + image_tokens->id = id; + + mtmd_input_chunk chunk{ + MTMD_INPUT_CHUNK_TYPE_IMAGE, + {}, // text tokens + std::move(image_tokens), + nullptr, // audio tokens + }; + chunks.emplace_back(std::move(chunk)); + } + + return chunks; + } + + // for example: "a <__media__> b <__media__> c" --> "a", "<__media__>", "b", "<__media__>", "c" + static std::vector split_text(const std::string & input, const std::string & delimiter) { + std::vector result; + if (input.empty()) { + return result; + } + size_t start = 0; + size_t pos = 0; + while ((pos = input.find(delimiter, start)) != std::string::npos) { + if (pos > start) { + result.push_back(input.substr(start, pos - start)); + } + result.push_back(delimiter); + start = pos + delimiter.length(); + } + if (start < input.length()) { + result.push_back(input.substr(start)); + } + return result; + } + + // copied from common_tokenize + static std::vector mtmd_tokenize_text_internal( + const struct llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + // upper limit for the number of tokens + int n_tokens = text.length() + 2 * add_special; + std::vector result(n_tokens); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; + } +}; + +int32_t mtmd_tokenize(mtmd_context * ctx, + mtmd_input_chunks * output, + const mtmd_input_text * text, + const mtmd_bitmap ** bitmaps, + size_t n_bitmaps) { + mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps); + return tokenizer.tokenize(output); } int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) { @@ -605,41 +740,54 @@ int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) { LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n"); return 0; } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + if (!ctx->ctx_v) { + LOG_ERR("%s: model does not support vision input\n", __func__); + return 1; + } return mtmd_encode(ctx, chunk->tokens_image.get()); } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); + if (!ctx->ctx_a) { + LOG_ERR("%s: model does not support audio input\n", __func__); + return 1; + } + int n_mmproj_embd = ctx->n_embd_text; ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd); bool ok = clip_image_batch_encode( - ctx->ctx_clip, + ctx->ctx_a, ctx->n_threads, &chunk->tokens_audio->batch_f32, ctx->image_embd_v.data()); return ok ? 0 : 1; } - LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type); + LOG_ERR("%s: unknown chunk type %d\n", __func__, (int)chunk->type); return 1; } int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { - int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); + clip_ctx * ctx_clip = ctx->ctx_v; + if (!ctx_clip) { + LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__); + return 1; + } + int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) { + if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { - int n_tokens_per_image = clip_n_output_tokens(ctx->ctx_clip, entries[i].get()); + int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get()); ok = clip_image_encode( - ctx->ctx_clip, + ctx_clip, ctx->n_threads, entries[i].get(), ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image); } } else { ok = clip_image_batch_encode( - ctx->ctx_clip, + ctx_clip, ctx->n_threads, &image_tokens->batch_f32, ctx->image_embd_v.data()); @@ -653,8 +801,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { - projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); - if (proj_type == PROJECTOR_TYPE_GEMMA3) { + if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) { return true; } return false; @@ -665,11 +812,11 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) { } bool mtmd_support_vision(mtmd_context * ctx) { - return ctx->has_vision; + return ctx->ctx_v != nullptr; } bool mtmd_support_audio(mtmd_context * ctx) { - return ctx->has_audio; + return ctx->ctx_a != nullptr; } // these 2 helpers below use internal clip_image_u8_ptr, diff --git a/tools/mtmd/test-2.mp3 b/tools/mtmd/test-2.mp3 new file mode 100644 index 0000000000000..aa9d7ec2c1dde Binary files /dev/null and b/tools/mtmd/test-2.mp3 differ diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 15a37b0d22bb4..aa0019893283e 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -25,80 +25,99 @@ RUN_HUGE_TESTS=false if [ "${1:-}" = "huge" ]; then RUN_HUGE_TESTS=true RUN_BIG_TESTS=true - echo "Include BIG models..." + echo "Include BIG and HUGE models..." fi ############### -arr_bin=() +arr_prefix=() arr_hf=() arr_tmpl=() # chat template +arr_file=() -add_test() { - local bin=$1 - local hf=$2 - local tmpl=${3:-""} # default to empty string if not provided - arr_bin+=("$bin") +add_test_vision() { + local hf=$1 + local tmpl=${2:-""} # default to empty string if not provided + arr_prefix+=("[vision]") arr_hf+=("$hf") arr_tmpl+=("$tmpl") + arr_file+=("test-1.jpeg") +} + +add_test_audio() { + local hf=$1 + arr_prefix+=("[audio] ") + arr_hf+=("$hf") + arr_tmpl+=("") # no need for chat tmpl + arr_file+=("test-2.mp3") } -add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0" -add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0" -add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M" -add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna" -add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna" -add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted -add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" -add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0" -add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "ggml-org/InternVL2_5-1B-GGUF:Q8_0" -add_test "llama-mtmd-cli" "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" +add_test_vision "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0" +add_test_vision "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M" +add_test_vision "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0" +add_test_vision "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" +add_test_vision "THUDM/glm-edge-v-5b-gguf:Q4_K_M" +add_test_vision "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna" +add_test_vision "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna" +add_test_vision "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" +add_test_vision "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted +add_test_vision "openbmb/MiniCPM-V-2_6-gguf:Q2_K" +add_test_vision "openbmb/MiniCPM-o-2_6-gguf:Q4_0" +add_test_vision "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" +add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" +add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0" +add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" +add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" + +add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" +add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" # to test the big models, run: ./tests.sh big if [ "$RUN_BIG_TESTS" = true ]; then - add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7" - add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" - # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra + add_test_vision "ggml-org/pixtral-12b-GGUF:Q4_K_M" + add_test_vision "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7" + add_test_vision "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" + # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra + + add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M" + add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" fi # to test the huge models, run: ./tests.sh huge # this will run both the big and huge models # huge models are > 32B parameters if [ "$RUN_HUGE_TESTS" = true ]; then - add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S" + add_test_vision "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" + add_test_vision "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S" fi # these models always give the wrong answer, not sure why -# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M" -# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0" -# add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0" +# add_test_vision "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M" +# add_test_vision "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0" +# add_test_vision "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0" # this model has broken chat template, not usable -# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" -# add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek" +# add_test_vision "cmp-nct/Yi-VL-6B-GGUF:Q5_K" +# add_test_vision "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek" ############### -cmake --build build -j --target "${arr_bin[@]}" +cmake --build build -j --target llama-mtmd-cli arr_res=() -for i in "${!arr_bin[@]}"; do - bin="${arr_bin[$i]}" +for i in "${!arr_hf[@]}"; do + bin="llama-mtmd-cli" + prefix="${arr_prefix[$i]}" hf="${arr_hf[$i]}" tmpl="${arr_tmpl[$i]}" + inp_file="${arr_file[$i]}" echo "Running test with binary: $bin and HF model: $hf" echo "" @@ -107,7 +126,7 @@ for i in "${!arr_bin[@]}"; do output=$(\ "$PROJ_ROOT/build/bin/$bin" \ -hf "$hf" \ - --image $SCRIPT_DIR/test-1.jpeg \ + --image $SCRIPT_DIR/$inp_file \ -p "what is the publisher name of the newspaper?" \ --temp 0 -n 128 \ ${tmpl:+--chat-template "$tmpl"} \ @@ -116,9 +135,9 @@ for i in "${!arr_bin[@]}"; do echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log if echo "$output" | grep -iq "new york"; then - result="\033[32mOK\033[0m: $bin $hf" + result="$prefix \033[32mOK\033[0m: $bin $hf" else - result="\033[31mFAIL\033[0m: $bin $hf" + result="$prefix \033[31mFAIL\033[0m: $bin $hf" fi echo -e "$result" arr_res+=("$result")