Commit e8b00ed

convert : improve model arch handling
1 parent d5fe4e8 commit e8b00ed

File tree: 1 file changed (+40 −21)
convert_hf_to_gguf.py

Lines changed: 40 additions & 21 deletions
@@ -66,8 +66,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +76,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
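Note that the two new class-level annotations only declare the attributes; the actual values are now assigned by the subclasses (see the TextModel and VisionModel hunks below). A quick standalone illustration of why annotations alone are enough here, using hypothetical classes that are not part of the file:

class Base:
    block_count: int              # annotation only; no value is stored on the class

class Child(Base):
    def __init__(self):
        self.block_count = 24     # the subclass is responsible for initializing it

assert not hasattr(Base, "block_count")   # the bare annotation creates no attribute
assert Child().block_count == 24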
@@ -113,8 +115,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -418,14 +418,7 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
     @staticmethod
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-            architectures = hparams.get("architectures")
-            if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
-            if architectures is not None:
-                # preserve "architectures" from root level config
-                hparams["architectures"] = architectures
-            return hparams
+            return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
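With the sub-config merge removed, load_hparams() now returns config.json exactly as parsed; flattening becomes the caller's responsibility. A minimal before/after sketch on a hypothetical nested config (values are illustrative only):

hparams = {
    "architectures": ["SmolVLMForConditionalGeneration"],
    "text_config": {"num_hidden_layers": 24, "hidden_size": 2048},
    "vision_config": {"num_hidden_layers": 12},
}

# old behavior: load_hparams() itself merged text_config into the root,
# so every caller saw the text model's values
old_view = {**hparams, **hparams["text_config"]}
assert old_view["num_hidden_layers"] == 24

# new behavior: the dict above comes back untouched; only TextModel.__init__
# (next hunk) performs the merge, so vision-side code keeps the nested layout
assert "num_hidden_layers" not in hparams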
@@ -454,6 +447,16 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
     @classmethod
     def __init_subclass__(cls):
         # can't use an abstract property, because overriding it without type errors
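For text models the flattening now happens once, up front, and block_count / tensor_map are derived from the merged view. A sketch of the resulting block-count lookup, reusing the hypothetical config from above (the key list matches the one in the diff):

hparams = {
    "architectures": ["SmolVLMForConditionalGeneration"],
    "text_config": {"num_hidden_layers": 24, "hidden_size": 2048},
}

# equivalent of the merge done in TextModel.__init__
hparams = {**hparams, **hparams["text_config"]}

# equivalent of find_hparam([...]) for the block count
keys = ("n_layers", "num_hidden_layers", "n_layer", "num_layers")
block_count = next(hparams[k] for k in keys if k in hparams)
assert block_count == 24   # the text model's layer count ends up at the root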
@@ -1078,8 +1081,12 @@ def __init__(self, *args, **kwargs):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

         # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        self.block_count = 512 # vision models are small, this "ought to be enough for anybody"
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         if "vision_config" not in self.hparams:
@@ -1726,20 +1733,20 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams)
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+        if arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
         # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+        if arch == "LlavaForConditionalGeneration" \
             and self.hparams.get("model_type") == "mistral":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)

@@ -5805,6 +5812,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n


+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()

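The helper prefers an "architectures" entry from the sub-config that matches the requested model type and falls back to the root entry otherwise; this is also why LlamaModel is now registered under "VLlama3ForCausalLM" in the hunk above. A usage sketch on a hypothetical nested config (illustrative values, standalone restatement of the same logic):

hparams = {
    "architectures": ["SmolVLMForConditionalGeneration"],      # root / wrapper arch
    "text_config": {"architectures": ["VLlama3ForCausalLM"]},  # inner text arch
    "vision_config": {},                                       # no override here
}

def resolve_arch(hparams: dict, use_vision: bool) -> str:
    # standalone restatement of get_model_architecture() for illustration
    sub = hparams.get("vision_config" if use_vision else "text_config", {})
    return (sub.get("architectures") or hparams["architectures"])[0]

assert resolve_arch(hparams, use_vision=False) == "VLlama3ForCausalLM"
assert resolve_arch(hparams, use_vision=True) == "SmolVLMForConditionalGeneration"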
@@ -5857,16 +5877,15 @@ def main() -> None:

     logger.info(f"Loading model: {dir_model.name}")

-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: