Commit 4840d2f

use AutoConfig
1 parent e8b00ed commit 4840d2f

File tree

1 file changed: +14 -26 lines changed


convert_hf_to_gguf.py

Lines changed: 14 additions & 26 deletions
@@ -16,6 +16,7 @@
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig

 import math
 import numpy as np
@@ -417,8 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]

     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            return json.load(f)
+        try:
+            return AutoConfig.from_pretrained(dir_model, trust_remote_code=True).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
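
The try/except above changes where hyperparameters come from: AutoConfig instantiates the checkpoint's config class (fetching custom configuration code when trust_remote_code=True), so class-level defaults are present even when a hand-written config.json omits them, while the raw JSON fallback only ever sees explicitly stored keys. A minimal standalone sketch of the same fallback, with "path/to/model" as a placeholder directory:

    from pathlib import Path
    import json
    import logging

    from transformers import AutoConfig

    logger = logging.getLogger(__name__)
    dir_model = Path("path/to/model")  # placeholder

    try:
        # Instantiates the matching *Config class, so architecture defaults
        # (e.g. num_attention_heads) are filled in even if config.json omits them.
        hparams = AutoConfig.from_pretrained(dir_model, trust_remote_code=True).to_dict()
    except Exception as e:
        logger.warning(f"Failed to load model config from {dir_model}: {e}")
        logger.warning("Trying to load config.json instead")
        # The raw JSON contains only the keys the checkpoint author wrote out.
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            hparams = json.load(f)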
@@ -1080,10 +1086,6 @@ def __init__(self, *args, **kwargs):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

-        # small hack to correct the number of layers
-        self.block_count = 512 # vision models are small, this "ought to be enough for anybody"
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
-
         # get n_embd of the text model
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
@@ -1095,6 +1097,9 @@ def __init__(self, *args, **kwargs):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]

+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
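
With self.hparams narrowed to vision_config, the layer count can now be read from the config instead of the old hard-coded 512-slot over-allocation. find_hparam returns the value for the first key that exists; a hedged standalone sketch of that lookup (the actual method in convert_hf_to_gguf.py may differ in details):

    from typing import Any, Iterable

    def find_hparam(hparams: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
        # Return the value of the first key present in hparams.
        for key in keys:
            if key in hparams:
                return hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {list(keys)}")

    # Illustrative values in the style of a SigLIP vision_config.
    vision_config = {"num_hidden_layers": 27, "hidden_size": 1152}
    block_count = find_hparam(vision_config, ["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
    print(block_count)  # 27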
@@ -1739,17 +1744,6 @@ class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if arch == "VLlama3ForCausalLM":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if arch == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
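
This whole __init__ override can go because the AutoConfig path now supplies num_attention_heads from the config class defaults. A quick check of that assumption using transformers' MistralConfig (values are class defaults, not read from any file):

    from transformers import MistralConfig

    cfg = MistralConfig()  # no num_attention_heads passed in
    print(cfg.num_attention_heads)                 # 32 -- the class default
    print("num_attention_heads" in cfg.to_dict())  # True, so load_hparams sees it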
@@ -1912,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12 # see tokenizer_config.json
         else:
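
layer_norm_eps is the one Pixtral default that must stay: per the added comment, HF hard-codes it in modeling_pixtral.py rather than serializing it to config.json. The get-with-default assignment is the script's usual idiom for this; dict.setdefault would be equivalent:

    hparams = {"model_type": "pixtral"}  # illustrative; the real dict comes from load_hparams()
    hparams["layer_norm_eps"] = hparams.get("layer_norm_eps", 1e-5)
    # equivalent, and only writes when the key is missing:
    hparams.setdefault("layer_norm_eps", 1e-5)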
@@ -1927,7 +1917,6 @@ def set_gguf_parameters(self):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)
@@ -1958,13 +1947,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
