@@ -45,7 +45,13 @@ def from_name(cls, name: str):
45
45
return cls (** transformer_configs [name ])
46
46
# fuzzy search
47
47
config = [config for config in transformer_configs if config in str (name ).upper () or config in str (name )]
48
- assert len (config ) == 1 , name
48
+
49
+ # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
50
+ # take longer name (as it have more symbols matched)
51
+ if len (config ) > 1 :
52
+ config .sort (key = len , reverse = True )
53
+ assert len (config [0 ]) != len (config [1 ]), name # make sure only one 'best' match
54
+
49
55
return cls (** transformer_configs [config [0 ]])
50
56
51
57
@@ -56,6 +62,7 @@ def from_name(cls, name: str):
56
62
"30B" : dict (n_layer = 60 , n_head = 52 , dim = 6656 ),
57
63
"34B" : dict (n_layer = 48 , n_head = 64 , dim = 8192 , vocab_size = 32000 , n_local_heads = 8 , intermediate_size = 22016 , rope_base = 1000000 ), # CodeLlama-34B-Python-hf
58
64
"70B" : dict (n_layer = 80 , n_head = 64 , dim = 8192 , n_local_heads = 8 , intermediate_size = 28672 ),
65
+ "Mistral-7B" : dict (n_layer = 32 , n_head = 32 , n_local_heads = 8 , dim = 4096 , intermediate_size = 14336 , vocab_size = 32000 ),
59
66
}
60
67
61
68
class KVCache (nn .Module ):
0 commit comments