Commit bc583e3

mtmd : support Qwen 2.5 Omni (input audio+vision, no audio output) (#13784)
* mtmd : allow multiple modalities at the same time
* refactor mtmd tokenizer
* fix compile
* ok, missing SinusoidsPositionEmbedding
* first working version
* fix style
* more strict validate of n_embd
* refactor if..else to switch
* fix regression
* add test for 3B
* update docs
* fix tokenizing with add_special
* add more tests
* fix test case "huge"
* rm redundant code
* set_position_mrope_1d rm n_tokens
1 parent 72b090d commit bc583e3

File tree: 12 files changed (+1021, -617 lines)

convert_hf_to_gguf.py

Lines changed: 140 additions & 34 deletions
@@ -432,6 +432,9 @@ def load_hparams(dir_model: Path):
         if "llm_config" in config:
             # rename for InternVL
             config["text_config"] = config["llm_config"]
+        if "thinker_config" in config:
+            # rename for Qwen2.5-Omni
+            config["text_config"] = config["thinker_config"]["text_config"]
         return config
 
     @classmethod
@@ -1121,18 +1124,21 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+
     has_vision_encoder: bool = True # by default
     has_audio_encoder: bool = False
 
+    # for models having multiple encoders, we need to separate their hparams
+    hparams_vision: dict[str, Any] | None = None
+    hparams_audio: dict[str, Any] | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
             raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
 
-        if self.has_vision_encoder and self.has_audio_encoder:
-            raise NotImplementedError("both vision + audio not supported yet")
-
         # get n_embd of the text model
         if "text_config" not in self.hparams:
             self.hparams["text_config"] = {}
@@ -1143,22 +1149,32 @@ def __init__(self, *args, **kwargs):
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         # move vision config to the top level, while preserving the original hparams in global_config
-        self.global_config = self.hparams
+        import copy
+        self.global_config = copy.deepcopy(self.hparams)
+        self.hparams_vision = self.get_vision_config()
+        self.hparams_audio = self.get_audio_config()
 
-        if "vision_config" in self.hparams:
-            self.hparams = self.hparams["vision_config"]
-        elif "audio_config" in self.hparams:
-            self.hparams = self.hparams["audio_config"]
-        else:
+        if self.hparams_vision is None and self.hparams_audio is None:
             raise ValueError("vision_config / audio_config not found in hparams")
 
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        # for compat with vision-only models
+        self.hparams = self.hparams_vision or self.hparams_audio or self.hparams
+
+        # TODO @ngxson : this is a hack to support both vision and audio encoders
+        have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
+        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
 
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
 
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("audio_config")
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
 
@@ -1170,33 +1186,49 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
 
             # vision config
-            self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
-            self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
-            self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_vision_block_count(self.block_count)
-            self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
+            self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
 
             # preprocessor config
             self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
             self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
 
-        elif self.has_audio_encoder:
+        if self.has_audio_encoder:
             self.gguf_writer.add_clip_has_audio_encoder(True)
             self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
 
             # audio config
-            self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_audio_block_count(self.block_count)
-            self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"]))
+            self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"]))
+            self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys))
+            self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"]))
 
         else:
             raise ValueError("MmprojModel must have either vision or audio encoder")
 
     def write_vocab(self):
         raise ValueError("MmprojModel does not support vocab writing")
 
+    def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_vision is not None
+        return self._find_param(self.hparams_vision, keys, optional)
+
+    def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_audio is not None
+        return self._find_param(self.hparams_audio, keys, optional)
+
+    def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
+        key = next((k for k in keys if k in obj), None)
+        if key is not None:
+            return obj[key]
+        if optional:
+            return None
+        raise KeyError(f"could not find any of: {keys}")
+
 
 @ModelBase.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(TextModel):
@@ -2674,7 +2706,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register(
+    "Qwen2VLModel",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "Qwen2_5OmniModel",
+)
 class Qwen2VLModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2VL
 
@@ -2692,8 +2729,11 @@ def set_vocab(self):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
-        if name.startswith("visual."):
-            # skip visual tensors
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+        if name.startswith("visual") or name.startswith("audio") or \
+                name.startswith("talker") or name.startswith("token2wav"):
+            # skip multimodal tensors
             return []
         return [(self.map_tensor_name(name), data_torch)]
 
@@ -2702,21 +2742,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Qwen2VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
         # rename config.json values
-        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
-        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
-        if "embed_dim" in self.hparams: # qwen2vl
-            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
-            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+        if "embed_dim" in self.hparams_vision: # qwen2vl
+            self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size")
+            self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim")
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        hparams = self.hparams
-        if self.global_config['model_type'] == 'qwen2_vl':
+        assert self.hparams_vision is not None
+        hparams = self.hparams_vision
+        model_type = self.global_config['model_type']
+        if model_type == 'qwen2_vl':
             self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
-        elif self.global_config['model_type'] == 'qwen2_5_vl':
-            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
+        elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni':
+            if model_type == 'qwen2_5_omni':
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
+            else:
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
             self.gguf_writer.add_vision_use_silu(True)
             # find n_wa_pattern (window attention pattern)
             fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@@ -2774,6 +2820,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [] # skip other tensors
 
 
+@ModelBase.register("Qwen2_5OmniModel")
+class Qwen25OmniModel(Qwen2VLVisionModel):
+    has_vision_encoder = True
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_audio is not None
+        self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
+        self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"]
+        self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_audio is not None
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("audio_config")
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # SinusoidsPositionEmbedding
+        assert self.hparams_audio is not None
+        max_timescale = 10000
+        length = 1500
+        channels = self.hparams_audio["hidden_size"]
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32)
+        yield ("audio_tower.embed_positions.weight", pos_embd)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims # unused
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+
+        if name.startswith("audio_tower"):
+            # process audio tensors
+            if "conv1.bias" in name or "conv2.bias" in name:
+                # transpose conv1 and conv2 bias
+                data_torch = data_torch.unsqueeze(-1)
+            if "audio_bos_eos_token" in name:
+                # this tensor is left unused in transformers code
+                # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
+                return []
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
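
For reference, the `SinusoidsPositionEmbedding` table produced by `generate_extra_tensors()` above can be reproduced and sanity-checked in isolation. A minimal sketch follows; the `channels=1280` default is an assumption about the audio tower's `d_model` and should be read from the checkpoint's config in practice:

```python
import numpy as np
import torch

def sinusoid_position_embedding(length: int = 1500, channels: int = 1280, max_timescale: float = 10000) -> torch.Tensor:
    # same construction as Qwen25OmniModel.generate_extra_tensors():
    # the first half of the channels holds sin, the second half holds cos
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32)

# the tensor written as "audio_tower.embed_positions.weight" has shape (length, channels)
print(sinusoid_position_embedding().shape)  # torch.Size([1500, 1280])
```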

docs/multimodal.md

Lines changed: 9 additions & 0 deletions
@@ -98,3 +98,12 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # note: no pre-quantized GGUF this model, as they have very poor result
 # ref: https://github.com/ggml-org/llama.cpp/pull/13760
 ```
+
+**Mixed modalities**:
+
+```sh
+# Qwen2.5 Omni
+# Capabilities: audio input, vision input
+(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
+(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
+```
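
As a usage sketch (not part of this diff), a mixed-input request would pass an image and an audio file together. The tool name and the `--image`/`--audio` flags below are assumptions based on the existing mtmd CLI options and may differ in your build:

```sh
# assumed invocation; run `llama-mtmd-cli --help` to confirm the flags in your build
llama-mtmd-cli -hf ggml-org/Qwen2.5-Omni-3B-GGUF \
    --image photo.jpg \
    --audio question.wav \
    -p "Describe the image and answer the spoken question."
```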

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -2260,6 +2260,7 @@ class VisionProjectorType:
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
     QWEN2A = "qwen2a" # audio
+    QWEN25O = "qwen2.5o" # omni
 
 
 # Items here are (block size, type size)

gguf-py/gguf/tensor_mapping.py

Lines changed: 5 additions & 0 deletions
@@ -1125,6 +1125,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.A_POST_NORM: (
             "audio_tower.layer_norm", # ultravox
+            "audio_tower.ln_post", # qwen2omni
         ),
 
         MODEL_TENSOR.A_ENC_ATTN_Q: (
@@ -1161,12 +1162,16 @@ class TensorNameMap:
             "audio_tower.layers.{bid}.fc2", # ultravox
         ),
 
+        # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors
+        # this prefix is added in the conversion code in modify_tensors()
+
         MODEL_TENSOR.A_MMPROJ: (
             "audio.multi_modal_projector.linear_{bid}", # ultravox
         ),
 
         MODEL_TENSOR.A_MMPROJ_FC: (
             "audio.multi_modal_projector.linear", # qwen2audio
+            "audio_tower.proj", # qwen2omni
         ),
 
         MODEL_TENSOR.A_MM_NORM_PRE: (

tools/mtmd/clip-impl.h

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,7 @@ enum projector_type {
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -148,6 +149,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_INTERNVL, "internvl"},
     { PROJECTOR_TYPE_LLAMA4,   "llama4"},
     { PROJECTOR_TYPE_QWEN2A,   "qwen2a"},
+    { PROJECTOR_TYPE_QWEN25O,  "qwen2.5o"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
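
The comment on `PROJECTOR_TYPE_QWEN25O` indicates the combined "qwen2.5o" projector is only a marker in the GGUF metadata: when a clip context is created it falls back to the existing single-modality projectors. An illustrative sketch of that resolution, using a hypothetical helper rather than the actual clip.cpp code:

```cpp
// hypothetical helper illustrating the comment above; the real resolution lives in clip.cpp
static projector_type resolve_qwen25o(projector_type type, bool is_audio_ctx) {
    if (type == PROJECTOR_TYPE_QWEN25O) {
        // the audio context reuses the Qwen2-Audio path, the vision context the Qwen2.5-VL path
        return is_audio_ctx ? PROJECTOR_TYPE_QWEN2A : PROJECTOR_TYPE_QWEN25VL;
    }
    return type;
}
```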
