@@ -12,6 +12,7 @@
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from transformers import AutoConfig

 import torch

@@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
-        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
+        "--base", type=Path,
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
     parser.add_argument(
         "lora_path", type=Path,
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()


+def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+    # normally, adapter does not come with base model config, we need to load it from AutoConfig
+    config = AutoConfig.from_pretrained(hf_model_id)
+    return config.to_dict()
+
+
 if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
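
For reference, AutoConfig.from_pretrained fetches only the model's config.json from the Hugging Face Hub (no weights are downloaded), and it raises OSError when the repo or its config cannot be reached, which the fallback further down relies on. An illustrative call to the new helper; the model ID is an arbitrary public example, not taken from this change:

# illustrative usage of the new helper (model ID chosen arbitrarily)
hparams = load_hparams_from_hf("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(hparams["architectures"])  # e.g. ['LlamaForCausalLM'], used later to pick the Model subclass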
@@ -281,7 +288,7 @@ def parse_args() -> argparse.Namespace:

     ftype = ftype_map[args.outtype]

-    dir_base_model: Path = args.base
+    dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,29 @@ def parse_args() -> argparse.Namespace:
         input_model = os.path.join(dir_lora, "adapter_model.bin")
         lora_model = torch.load(input_model, map_location="cpu", weights_only=True)

+    # load LoRA config
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
     # load base model
-    logger.info(f"Loading base model: {dir_base_model.name}")
-    hparams = Model.load_hparams(dir_base_model)
+    if dir_base_model is None:
+        if "base_model_name_or_path" in lparams:
+            model_id = lparams["base_model_name_or_path"]
+            logger.info(f"Loading base model from Hugging Face: {model_id}")
+            try:
+                hparams = load_hparams_from_hf(model_id)
+            except OSError as e:
+                logger.error(f"Failed to load base model config: {e}")
+                logger.error("Please try downloading the base model and add its path to --base")
+                sys.exit(1)
+        else:
+            logger.error("'base_model_name_or_path' is not found in adapter_config.json")
+            logger.error("Base model config is required. Please download the base model and add its path to --base")
+            sys.exit(1)
+    else:
+        logger.info(f"Loading base model: {dir_base_model.name}")
+        hparams = Model.load_hparams(dir_base_model)
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
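
For context, lparams is the parsed PEFT adapter_config.json, now loaded before the base model so that base_model_name_or_path is available for the Hub fallback; lora_alpha is read from the same dict further down. A trimmed sketch of what json.load typically returns for a PEFT adapter (values illustrative):

# trimmed, illustrative adapter_config.json contents after json.load
lparams = {
    "base_model_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # consumed by the Hub fallback above
    "lora_alpha": 16,                                                 # consumed later for gguf.Keys.Adapter.LORA_ALPHA
    "r": 8,
    "target_modules": ["q_proj", "v_proj"],
}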
@@ -323,13 +350,15 @@ def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
                 self.dir_model_card = dir_lora_model
                 self.lora_alpha = float(lora_alpha)

+            def set_vocab(self):
+                pass
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

             def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
-                super().set_gguf_parameters()

             def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
@@ -350,7 +379,7 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                        logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                        if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                            logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
-                            logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
+                            logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                        sys.exit(1)

                    if base_name in tensor_map:
@@ -384,9 +413,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                    yield (dest_name + ".lora_a", lora_a)
                    yield (dest_name + ".lora_b", lora_b)

-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
         alpha: float = lparams["lora_alpha"]

         model_instance = LoraModel(
@@ -399,6 +425,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
+            hparams=hparams,
         )

         logger.info("Exporting model...")
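
Net effect: --base becomes optional. A usage sketch, assuming the file is llama.cpp's convert_lora_to_gguf.py (the script name is inferred from the repository of the PR linked above, not stated in this diff):

# before this change, --base was mandatory
python convert_lora_to_gguf.py --base ./base-model-dir ./my-lora-adapter

# with this change, it may be omitted; the base config is then fetched from
# the Hub using "base_model_name_or_path" from adapter_config.json
python convert_lora_to_gguf.py ./my-lora-adapter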