@@ -16,6 +16,7 @@
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig

 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]

     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        architectures = hparams.get("architectures")
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        if architectures is not None:
-            # preserve "architectures" from root level config
-            hparams["architectures"] = architectures
-        return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
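For context: `AutoConfig.from_pretrained(...).to_dict()` resolves the checkpoint's config class and therefore includes default values that a hand-written `config.json` may omit, which is what lets the per-model "missing key" workarounds further down be dropped; the `except` branch keeps the old raw-file behaviour as a fallback. A standalone sketch of the same pattern, assuming `transformers` is installed and `dir_model` contains a `config.json`:

```python
import json
import logging
from pathlib import Path

from transformers import AutoConfig

logger = logging.getLogger(__name__)


def load_hparams_sketch(dir_model: Path) -> dict:
    try:
        # resolved config: defaults from the model's config class are included
        return AutoConfig.from_pretrained(dir_model).to_dict()
    except Exception as e:
        logger.warning(f"Failed to load model config from {dir_model}: {e}")
        logger.warning("Trying to load config.json instead")
    # fallback: raw file contents, no defaults applied
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)
```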
@@ -454,6 +453,23 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()

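The `__init_subclass__` hook runs when a subclass is defined, so a model class that forgets to set `model_arch` fails at import time rather than partway through a conversion. A self-contained sketch of the same guard, with a plain string standing in for `gguf.MODEL_ARCH`:

```python
class TextModelSketch:
    # subclasses must define this in their own class body
    model_arch: str

    @classmethod
    def __init_subclass__(cls):
        # cls.__dict__ only contains names defined on the subclass itself,
        # so inheriting model_arch from a parent does not satisfy the check
        if "model_arch" not in cls.__dict__:
            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")


class LlamaSketch(TextModelSketch):
    model_arch = "llama"  # OK: defined on the subclass itself


try:
    class BrokenSketch(TextModelSketch):  # no model_arch -> raises at class-definition time
        pass
except TypeError as e:
    print(e)  # Missing property 'model_arch' for 'BrokenSketch'
```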
@@ -1070,9 +1086,9 @@ def __init__(self, *args, **kwargs):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ def __init__(self, *args, **kwargs):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]

+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
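The candidate list for the vision block count is longer than the text one because vision configs disagree on the key name (for example, Qwen2-VL-style vision configs call it `depth`). A simplified stand-in for the first-match lookup that `find_hparam` performs here, with a hypothetical vision config:

```python
from typing import Any


def first_present(hparams: dict[str, Any], keys: list[str]) -> Any:
    # simplified stand-in for ModelBase.find_hparam: return the value of the
    # first candidate key found in hparams, else raise
    for key in keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"could not find any of: {keys}")


vision_cfg = {"depth": 32, "hidden_size": 1280}  # hypothetical Qwen2-VL-style vision_config
block_count = first_present(vision_cfg, ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
print(block_count)  # 32
```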
@@ -1098,7 +1117,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))

         # preprocessor config
@@ -1719,23 +1738,12 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-            and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12  # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ def set_gguf_parameters(self):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)

@@ -1944,13 +1947,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3505,6 +3507,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5849,6 +5853,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n


+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()

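`get_model_architecture` lets `--mmproj` conversions pick the architecture named inside a sub-config when a multimodal checkpoint nests one. A small usage sketch with a hypothetical hparams dict, simplified to take the dict directly instead of a model directory (`ModelType` here is a stand-in for the script's enum):

```python
from enum import Enum
from typing import Any


class ModelType(Enum):  # stand-in for the script's ModelType enum
    TEXT = 1
    VISION = 2


def get_model_architecture_sketch(hparams: dict[str, Any], model_type: ModelType) -> str:
    text_config = hparams.get("text_config", {})
    vision_config = hparams.get("vision_config", {})
    arch = hparams["architectures"][0]
    # if "architectures" is found in the sub-config, use that instead
    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
        arch = text_config["architectures"][0]
    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
        arch = vision_config["architectures"][0]
    return arch


# hypothetical multimodal config: the root architecture names the wrapper model,
# the nested text_config names the language model actually being converted
hparams = {
    "architectures": ["LlavaForConditionalGeneration"],
    "text_config": {"architectures": ["MistralForCausalLM"]},
    "vision_config": {},
}
print(get_model_architecture_sketch(hparams, ModelType.TEXT))    # MistralForCausalLM
print(get_model_architecture_sketch(hparams, ModelType.VISION))  # LlavaForConditionalGeneration
```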
@@ -5901,16 +5918,15 @@ def main() -> None:

     logger.info(f"Loading model: {dir_model.name}")

-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: