Commit 053367d

mtmd : support InternVL 2.5 and 3 (#13422)
* convert : internvl support
* InternVL3-1B working
* fix regression
* rm mobilevlm from test
* fix conversion
* add test for internvl
* add to list of pre-quant
* restore boi/eoi check
* add clarify comment for norm eps
1 parent d891942 commit 053367d

File tree

9 files changed: +243 -25 lines changed

convert_hf_to_gguf.py

Lines changed: 75 additions & 1 deletion
```diff
@@ -426,7 +426,11 @@ def load_hparams(dir_model: Path):
             logger.warning(f"Failed to load model config from {dir_model}: {e}")
             logger.warning("Trying to load config.json instead")
             with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-                return json.load(f)
+                config = json.load(f)
+                if "llm_config" in config:
+                    # rename for InternVL
+                    config["text_config"] = config["llm_config"]
+                return config

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
```
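For context, this fallback now normalizes InternVL-style configs, which nest the language-model hyperparameters under `llm_config` rather than the HF-standard `text_config`. A minimal sketch of the behavior (the example values are made up, not from a real checkpoint):

```python
import json

def load_config(path: str) -> dict:
    # Same idea as the hunk above: alias InternVL's "llm_config" to the
    # "text_config" key the rest of the converter expects.
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if "llm_config" in config:
        config["text_config"] = config["llm_config"]
    return config

# A config like {"llm_config": {"hidden_size": 896}} then satisfies code
# that reads config["text_config"]["hidden_size"].
```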
```diff
@@ -2606,6 +2610,11 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if self.hf_arch == "Qwen2Model":
             name = f"model.{name}" # map to Qwen2ForCausalLM tensors
+        if "language_model." in name:
+            name = name.replace("language_model.", "") # for InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
         yield from super().modify_tensors(data_torch, name, bid)


```
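The same two-step normalization recurs in the other text-model hunks below: strip the `language_model.` prefix InternVL puts on its LM tensors, then skip vision tensors, which the `--mmproj` conversion path handles instead. A quick illustration (the tensor names are hypothetical examples):

```python
def normalize(name: str) -> str | None:
    name = name.replace("language_model.", "")  # for InternVL
    if name.startswith("mlp") or name.startswith("vision_model"):
        return None  # visual tensor: skipped by the text converter
    return name

assert normalize("language_model.model.layers.0.self_attn.q_proj.weight") == \
    "model.layers.0.self_attn.q_proj.weight"
assert normalize("vision_model.encoder.layers.0.attn.qkv.weight") is None
```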
```diff
@@ -2709,6 +2718,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [] # skip other tensors


+@ModelBase.register("InternVisionModel")
+class InternVisionModel(VisionModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+        # hidden_act
+        if hparams["hidden_act"] == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+        elif hparams["hidden_act"] == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        else:
+            raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+        # downsample_ratio
+        downsample_ratio = self.global_config.get("downsample_ratio")
+        assert downsample_ratio is not None
+        self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        if name.startswith("vision_model") or name.startswith("mlp"):
+            # process visual tensors
+            # correct name
+            if name.startswith("vision_model"):
+                name = "vision_tower." + name
+            if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
+                name += ".weight"
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2: # weight
+                    c3, _ = data_torch.shape
+                else: # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq),
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk),
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv),
+                ]
+            return [(self.map_tensor_name(name), data_torch)]
+        return [] # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
```
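Two details of `InternVisionModel` are worth unpacking. The projector scale factor is the inverse of InternVL's pixel-shuffle `downsample_ratio`, so a typical ratio of 0.5 is stored as `int(1.0 / 0.5) == 2`. And because InternViT fuses Q, K and V into a single `qkv` tensor, `modify_tensors` slices it into thirds along the output dimension. A minimal torch sketch of that split (shapes are illustrative):

```python
import torch

hidden = 1024
qkv_weight = torch.randn(3 * hidden, hidden)  # fused rows: [q; k; v]
qkv_bias = torch.randn(3 * hidden)

def split_qkv(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    c3 = t.shape[0]      # first dim works for 2-D weights and 1-D biases alike
    assert c3 % 3 == 0
    c = c3 // 3
    return t[:c], t[c:c * 2], t[c * 2:]

wq, wk, wv = split_qkv(qkv_weight)
bq, bk, bv = split_qkv(qkv_bias)
assert wq.shape == (hidden, hidden) and bq.shape == (hidden,)
```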
```diff
@@ -3360,6 +3425,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv

+        name = name.replace("language_model.", "") # InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
+
         if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch

@@ -3433,6 +3503,10 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
+        name = name.replace("language_model.", "") # InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
         if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
         if name.endswith(("k_proj.weight", "k_proj.bias")):
```

docs/multimodal.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -66,4 +66,12 @@ NOTE: some models may require large context window, for example: `-c 8192`

 # Mistral Small 3.1 24B (IQ2_M quantization)
 (tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF
+
+# InternVL 2.5 and 3
+(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF
+(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF
+(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
 ```
````

gguf-py/gguf/constants.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -491,6 +491,8 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_FFN_UP = auto()
     V_ENC_FFN_GATE = auto()
     V_ENC_FFN_DOWN = auto()
+    V_LAYER_SCALE_1 = auto()
+    V_LAYER_SCALE_2 = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
     V_MM_INP_NORM = auto()
@@ -748,6 +750,8 @@
     MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
     MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
     MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
+    MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
+    MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
     MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
@@ -786,6 +790,8 @@
     MODEL_TENSOR.V_ENC_FFN_UP,
     MODEL_TENSOR.V_ENC_FFN_GATE,
     MODEL_TENSOR.V_ENC_FFN_DOWN,
+    MODEL_TENSOR.V_LAYER_SCALE_1,
+    MODEL_TENSOR.V_LAYER_SCALE_2,
     MODEL_TENSOR.V_PRE_NORM,
     MODEL_TENSOR.V_POST_NORM,
     MODEL_TENSOR.V_MM_INP_PROJ,
@@ -2167,6 +2173,7 @@ class VisionProjectorType:
     PIXTRAL = "pixtral"
     QWEN2VL = "qwen2vl_merger"
     QWEN25VL = "qwen2.5vl_merger"
+    INTERNVL = "internvl"


 # Items here are (block size, type size)
```
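The `"v.blk.{bid}.ls1"` strings are per-block name templates; `{bid}` is filled with the block index when tensors are written. A small sanity check of how the new entries expand (`TENSOR_NAMES` here is a stand-in for the dict this hunk extends):

```python
# Stand-in for the gguf-py mapping extended above.
TENSOR_NAMES = {
    "V_LAYER_SCALE_1": "v.blk.{bid}.ls1",
    "V_LAYER_SCALE_2": "v.blk.{bid}.ls2",
}

assert TENSOR_NAMES["V_LAYER_SCALE_1"].format(bid=7) == "v.blk.7.ls1"
assert TENSOR_NAMES["V_LAYER_SCALE_2"].format(bid=7) == "v.blk.7.ls2"
```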

gguf-py/gguf/tensor_mapping.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -905,6 +905,7 @@ class TensorNameMap:

         MODEL_TENSOR.V_MMPROJ_MLP: (
             "model.mm_projector.mlp.mlp.{bid}",
+            "mlp1.{bid}", # InternVL
         ),

         MODEL_TENSOR.V_MMPROJ_PEG: (
@@ -955,6 +956,7 @@

         MODEL_TENSOR.V_ENC_INPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
+            "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
@@ -963,6 +965,7 @@

         MODEL_TENSOR.V_ENC_OUTPUT: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
+            "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
@@ -971,6 +974,7 @@

         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
+            "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
@@ -1000,6 +1004,14 @@
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),

+        MODEL_TENSOR.V_LAYER_SCALE_1: (
+            "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
+        ),
+
+        MODEL_TENSOR.V_LAYER_SCALE_2: (
+            "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
+        ),
+
         MODEL_TENSOR.V_PRE_NORM: (
             "vision_tower.vision_model.pre_layrnorm",
             "vision_tower.ln_pre", # pixtral
```

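Semantically, `ls1`/`ls2` are InternViT's layer-scale parameters: learned per-channel gains applied to each block's attention and MLP residual branches, which is why they get dedicated tensor types instead of reusing a norm slot. A sketch of how `ls1` acts at inference time (shapes and values are illustrative, not from a real checkpoint):

```python
import torch

hidden = 1024
ls1 = torch.full((hidden,), 0.01)  # stored in GGUF as v.blk.{bid}.ls1.weight
x = torch.randn(1, 256, hidden)    # (batch, patch tokens, channels)
attn_out = torch.randn_like(x)     # stand-in for the attention branch output
x = x + attn_out * ls1             # layer scale: per-channel gain, broadcast
```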
tools/mtmd/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -48,6 +48,7 @@ For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag:
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
 - Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
 - [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` models, only the non-HF version is supported; the `InternLM2Model` **text** model is not supported)

 For older models, please refer to the relevant guide for instructions on how to obtain or create them:

```
tools/mtmd/clip-impl.h

Lines changed: 6 additions & 5 deletions
```diff
@@ -33,9 +33,6 @@
 #define KEY_PROJ_TYPE              "clip.projector_type"
 #define KEY_SPATIAL_MERGE_SIZE     "clip.vision.spatial_merge_size"

-#define KEY_USE_GLU_MLP            "clip.use_glu_mlp"  // for qwen2.5vl
-#define KEY_USE_RMS_NORM           "clip.use_rms_norm" // for qwen2.5vl
-
 #define KEY_MM_PATCH_MERGE_TYPE    "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS   "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION  "clip.vision.image_crop_resolution"
@@ -60,8 +57,10 @@
 #define TN_FFN_GATE   "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP     "%s.blk.%d.ffn_up.%s"
 #define TN_FFN_GATE   "%s.blk.%d.ffn_gate.%s"
-#define TN_LN_1       "%s.blk.%d.ln1.%s"
-#define TN_LN_2       "%s.blk.%d.ln2.%s"
+#define TN_LN_1       "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2       "%s.blk.%d.ln2.%s" // layer norm
+#define TN_LS_1       "%s.blk.%d.ls1.%s" // layer scale
+#define TN_LS_2       "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE     "%s.pre_ln.%s"
 #define TN_LN_POST    "%s.post_ln.%s"
 #define TN_LLAVA_PROJ "mm.%d.%s"
@@ -105,6 +104,7 @@ enum projector_type {
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -119,6 +119,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
+    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
```
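The `TN_LS_*` format strings are the C++ counterparts of the gguf-py templates added earlier; both sides have to agree on the final tensor name. A quick illustrative check in Python (the `.weight` suffix follows the converter's naming above):

```python
# gguf-py side: template plus suffix; clip.cpp side: TN_LS_1 % (prefix, bid, suffix).
writer_side = "v.blk.{bid}.ls1".format(bid=5) + ".weight"
reader_side = "%s.blk.%d.ls1.%s" % ("v", 5, "weight")
assert writer_side == reader_side == "v.blk.5.ls1.weight"
```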
