@@ -14,6 +14,7 @@

 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
@@ -62,6 +64,7 @@

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
+
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass

         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict

-    return state_dict
+    return state_dict, metadata


 def _best_guess_weight_name(
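For context, a minimal sketch of how the metadata returned above can be recovered from a `.safetensors` header; `_load_sft_state_dict_metadata` is assumed to do something along these lines. The helper name `read_lora_adapter_metadata` is a hypothetical stand-in, with only `safetensors.safe_open` and its `.metadata()` accessor taken as given:

```python
import json

from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


def read_lora_adapter_metadata(path):
    # safe_open exposes the free-form str -> str header of a .safetensors file
    # via .metadata(); it returns None when no header was written.
    with safe_open(path, framework="pt") as f:
        raw = f.metadata() or {}
    value = raw.get(LORA_ADAPTER_METADATA_KEY)
    # The adapter config is stored as a JSON string (see write_lora_layers below).
    return json.loads(value) if value is not None else None
```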
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name


+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
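A quick illustration of what the new `_pack_dict_with_prefix` helper produces; the tensor keys below are made up for the example:

```python
import torch

# Hypothetical flat LoRA entries; only the key prefixing matters here.
sd = {
    "q_proj.lora_A.weight": torch.zeros(4, 16),
    "q_proj.lora_B.weight": torch.zeros(16, 4),
}

packed = _pack_dict_with_prefix(sd, "text_encoder")
print(sorted(packed))
# ['text_encoder.q_proj.lora_A.weight', 'text_encoder.q_proj.lora_B.weight']
```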
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")

-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e

         # adapter_name
         if adapter_name is None:
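The `TypeError` guard above exists because the metadata dict is passed straight to PEFT. A sketch of the happy path, assuming `peft` is installed and using only standard `LoraConfig` fields:

```python
from peft import LoraConfig

# Metadata recovered from the safetensors header is just a dict of LoraConfig
# kwargs; a key written by an incompatible PEFT version would raise TypeError,
# which the change above re-raises with a clearer message.
metadata = {"r": 8, "lora_alpha": 8, "target_modules": ["q_proj", "v_proj"]}
lora_config = LoraConfig(**metadata)
```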
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)

     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:

                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)

             else:
                 save_function = torch.save
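Putting the two halves together, a hedged round-trip sketch of how `lora_adapter_metadata` ends up in, and comes back out of, the safetensors header. The file name and tensor key are illustrative; the `safetensors` calls (`save_file` with `metadata=`, `safe_open(...).metadata()`) are existing API:

```python
import json

import safetensors.torch
import torch
from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

weights = {"text_encoder.q_proj.lora_A.weight": torch.zeros(4, 16)}
adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["q_proj"]}

# Save: the adapter config is JSON-encoded into the (str -> str) safetensors header.
header = {
    "format": "pt",
    LORA_ADAPTER_METADATA_KEY: json.dumps(adapter_metadata, indent=2, sort_keys=True),
}
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=header)

# Load: the header round-trips independently of the tensors.
with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    restored = json.loads(f.metadata()[LORA_ADAPTER_METADATA_KEY])
assert restored == adapter_metadata
```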