
Commit 368958d

sayakpaul, BenjaminBossan, and linoytsaban authored
[LoRA] parse metadata from LoRA and save metadata (#11324)
* feat: parse metadata from lora state dicts.
* tests
* fix tests
* key renaming
* fix
* smol update
* smol updates
* load metadata.
* automatically save metadata in save_lora_adapter.
* propagate changes.
* changes
* add test to models too.
* tigher tests.
* updates
* fixes
* rename tests.
* sorted.
* Update src/diffusers/loaders/lora_base.py
  Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* review suggestions.
* removeprefix.
* propagate changes.
* fix-copies
* sd
* docs.
* fixes
* get review ready.
* one more test to catch error.
* change to a different approach.
* fix-copies.
* todo
* sd3
* update
* revert changes in get_peft_kwargs.
* update
* fixes
* fixes
* simplify _load_sft_state_dict_metadata
* update
* style fix
* uipdate
* update
* update
* empty commit
* _pack_dict_with_prefix
* update
* TODO 1.
* todo: 2.
* todo: 3.
* update
* update
* Apply suggestions from code review
  Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* reraise.
* move argument.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent e52ceae commit 368958d
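The core of the change: the LoRA adapter configuration is carried inside the safetensors header instead of being reverse-engineered from network alphas at load time. A minimal sketch of that round trip using only the safetensors and json APIs; the file name matches LORA_WEIGHT_NAME_SAFE from the diff, while the keys and config values are illustrative and this is not the pipeline-level API the commit adds:

    import json

    import safetensors.torch
    import torch
    from safetensors import safe_open

    # Toy LoRA state dict and the LoraConfig-style kwargs we want to persist.
    state_dict = {"unet.to_q.lora_A.weight": torch.zeros(4, 16)}
    lora_adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q"]}

    # Save: the config is JSON-encoded into the safetensors header next to "format".
    header = {"format": "pt"}
    header["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
    safetensors.torch.save_file(state_dict, "pytorch_lora_weights.safetensors", metadata=header)

    # Load: the header comes back as plain strings and can be decoded again.
    with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
        recovered = json.loads(f.metadata()["lora_adapter_metadata"])
    assert recovered["r"] == 4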

File tree

11 files changed: +845 -199 lines changed

examples/community/ip_adapter_face_id.py

Lines changed: 1 addition & 4 deletions
@@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,

src/diffusers/loaders/ip_adapter.py

Lines changed: 3 additions & 12 deletions
@@ -159,10 +159,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(

src/diffusers/loaders/lora_base.py

Lines changed: 46 additions & 6 deletions
@@ -14,6 +14,7 @@
 
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata
 
 
 if is_transformers_available():
@@ -62,6 +64,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
 
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
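`_load_sft_state_dict_metadata` itself lives in `src/diffusers/utils/state_dict_utils.py` and is not part of the hunks shown here. Based on how it is called below, a plausible sketch of its behavior (an assumption, not the actual implementation):

    import json

    from safetensors import safe_open

    def _load_sft_state_dict_metadata_sketch(model_file):
        # Read the safetensors header and JSON-decode the LoRA adapter entry, if any.
        with safe_open(model_file, framework="pt", device="cpu") as f:
            raw_metadata = f.metadata() or {}
        raw = raw_metadata.get("lora_adapter_metadata")  # LORA_ADAPTER_METADATA_KEY
        return json.loads(raw) if raw is not None else None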
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
+
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass
 
         if model_file is None:
@@ -261,10 +268,11 @@
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict
 
-    return state_dict
+    return state_dict, metadata
 
 
 def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
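`_pack_dict_with_prefix` simply namespaces every key of a mapping; `pack_weights` is rewritten on top of it later in this diff. With toy values:

    packed = _pack_dict_with_prefix({"lora_A.weight": 1, "lora_B.weight": 2}, "unet")
    assert packed == {"unet.lora_A.weight": 1, "unet.lora_B.weight": 2}

Because it works on any dict, the same helper can prefix adapter metadata as well as weights.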
@@ -317,10 +330,14 @@
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")
 
+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
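When metadata was recovered from the file, it becomes the `LoraConfig` kwargs verbatim instead of being re-derived by the `get_peft_kwargs` heuristics; the try/except added in the next hunk then surfaces incompatible keys as a `TypeError`. A toy reconstruction (requires `peft`; the values are illustrative):

    from peft import LoraConfig

    lora_config_kwargs = {"r": 8, "lora_alpha": 16, "target_modules": ["to_q", "to_k"], "use_dora": False}
    lora_config = LoraConfig(**lora_config_kwargs)  # raises TypeError on unknown keys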
@@ -398,7 +420,10 @@
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")
 
-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e
 
         # adapter_name
         if adapter_name is None:
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)
 
     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:
 
                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
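Taken together, `write_lora_layers` can now persist the adapter config alongside the weights. A hedged sketch of driving the static helper directly; the path, tensors, and config values are illustrative, parameters not visible in this hunk (such as `is_main_process`) are assumed to follow the existing signature, and a set such as `target_modules` is converted to a list before JSON serialization, as the diff above shows:

    import torch

    from diffusers.loaders.lora_base import LoraBaseMixin

    LoraBaseMixin.write_lora_layers(
        state_dict={"unet.to_q.lora_A.weight": torch.zeros(4, 320)},
        save_directory="./lora-out",
        is_main_process=True,
        weight_name="pytorch_lora_weights.safetensors",
        save_function=None,
        safe_serialization=True,
        lora_adapter_metadata={"r": 4, "lora_alpha": 4, "target_modules": {"to_q"}},
    )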