@@ -66,8 +66,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +76,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +115,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -418,14 +418,7 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
     @staticmethod
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-            architectures = hparams.get("architectures")
-            if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
-            if architectures is not None:
-                # preserve "architectures" from root level config
-                hparams["architectures"] = architectures
-            return hparams
+            return json.load(f)
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
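
Note on the hunk above: `load_hparams` now returns the raw `config.json` contents unchanged, and all sub-config handling moves to the callers. A minimal sketch of the new behavior, assuming a hypothetical multimodal config:

    # hypothetical config.json contents (illustrative values only):
    # {"architectures": ["LlavaForConditionalGeneration"],
    #  "text_config": {"num_hidden_layers": 32, "hidden_size": 4096}}
    hparams = ModelBase.load_hparams(dir_model)
    assert "text_config" in hparams            # sub-config stays nested
    assert "num_hidden_layers" not in hparams  # no longer hoisted to the root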
@@ -454,6 +447,16 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
 
 
 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
     @classmethod
     def __init_subclass__(cls):
         # can't use an abstract property, because overriding it without type errors
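
The merge above gives `text_config` keys precedence over duplicated root-level keys, including `architectures`, which the removed `load_hparams` logic used to preserve; architecture selection now goes through the `get_model_architecture` helper added near the end of this diff. A small sketch of the merge semantics with made-up values:

    hparams = {"vocab_size": 32000, "hidden_size": 0, "text_config": {"hidden_size": 4096}}
    merged = {**hparams, **hparams["text_config"]}
    assert merged["vocab_size"] == 32000  # root-only keys survive
    assert merged["hidden_size"] == 4096  # text_config wins on conflicts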
@@ -1078,8 +1081,12 @@ def __init__(self, *args, **kwargs):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
         # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        self.block_count = 512  # vision models are small, this "ought to be enough for anybody"
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
 
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         if "vision_config" not in self.hparams:
@@ -1726,20 +1733,20 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams)
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+        if arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
         # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+        if arch == "LlavaForConditionalGeneration" \
             and self.hparams.get("model_type") == "mistral":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
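
As the comments suggest, SmolVLM2-style configs presumably name the wrapper model at the root and the text model inside `text_config`, so the check now matches the architecture resolved by the `get_model_architecture` helper added below. A sketch of the resolution this relies on (hypothetical config values):

    hparams = {
        "architectures": ["SmolVLMForConditionalGeneration"],
        "text_config": {"architectures": ["VLlama3ForCausalLM"]},
    }
    assert get_model_architecture(dir_model, ModelType.TEXT, hparams) == "VLlama3ForCausalLM"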
@@ -5805,6 +5812,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
 
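
When the relevant sub-config carries no `architectures` entry, the new helper falls back to the root-level one, and passing `hparams` explicitly (as `LlamaModel.__init__` does above) skips re-reading `config.json`. A usage sketch with made-up values:

    hparams = {"architectures": ["LlavaForConditionalGeneration"], "vision_config": {}}
    arch = get_model_architecture(dir_model, ModelType.VISION, hparams)
    assert arch == "LlavaForConditionalGeneration"  # root-level fallback
    # with hparams=None, the helper loads config.json from dir_model itself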
@@ -5857,16 +5877,15 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: