[LoRA] parse metadata from LoRA and save metadata #11324


Merged
84 commits, merged on Jun 13, 2025

Changes from all commits (84 commits)
5139de1
feat: parse metadata from lora state dicts.
sayakpaul Apr 15, 2025
d8a305e
tests
sayakpaul Apr 15, 2025
ba546bc
fix tests
sayakpaul Apr 15, 2025
25f826e
Merge branch 'main' into metadata-lora
sayakpaul Apr 15, 2025
61d3708
key renaming
sayakpaul Apr 15, 2025
e98fb84
fix
sayakpaul Apr 15, 2025
2f1c326
Merge branch 'main' into metadata-lora
sayakpaul Apr 15, 2025
d390d4d
Merge branch 'main' into metadata-lora
sayakpaul Apr 16, 2025
201bd7b
resolve conflicts.
sayakpaul Apr 21, 2025
a771982
Merge branch 'main' into metadata-lora
sayakpaul May 2, 2025
42bb6bc
smol update
sayakpaul May 2, 2025
7ec4ef4
smol updates
sayakpaul May 2, 2025
7f59ca0
load metadata.
sayakpaul May 2, 2025
ded2fd6
automatically save metadata in save_lora_adapter.
sayakpaul May 2, 2025
d5b3037
propagate changes.
sayakpaul May 2, 2025
bee9e00
changes
sayakpaul May 2, 2025
a9f5088
add test to models too.
sayakpaul May 2, 2025
7716303
tigher tests.
sayakpaul May 2, 2025
0ac1a39
updates
sayakpaul May 2, 2025
4b51bbf
fixes
sayakpaul May 2, 2025
e2ca95a
rename tests.
sayakpaul May 2, 2025
7a2ba69
Merge branch 'main' into metadata-lora
sayakpaul May 3, 2025
e0449c2
sorted.
sayakpaul May 3, 2025
918aef1
Update src/diffusers/loaders/lora_base.py
sayakpaul May 3, 2025
4bd325c
review suggestions.
sayakpaul May 3, 2025
e8bec86
removeprefix.
sayakpaul May 5, 2025
aa5cb3c
Merge branch 'main' into metadata-lora
sayakpaul May 5, 2025
7bb6c9f
propagate changes.
sayakpaul May 8, 2025
116306e
fix-copies
sayakpaul May 8, 2025
ae0580a
sd
sayakpaul May 8, 2025
f6fde6f
docs.
sayakpaul May 8, 2025
cbb4071
resolve conflicts.
sayakpaul May 8, 2025
87417b2
fixes
sayakpaul May 8, 2025
55a41bf
Merge branch 'main' into metadata-lora
sayakpaul May 9, 2025
16dba2d
get review ready.
sayakpaul May 9, 2025
023c0fe
Merge branch 'main' into metadata-lora
sayakpaul May 9, 2025
67bceda
one more test to catch error.
sayakpaul May 9, 2025
83a8995
merge conflicts.
sayakpaul May 9, 2025
d336486
Merge branch 'main' into metadata-lora
sayakpaul May 11, 2025
4f2d90c
Merge branch 'main' into metadata-lora
sayakpaul May 12, 2025
42a0d1c
Merge branch 'main' into metadata-lora
sayakpaul May 15, 2025
9c32dc2
Merge branch 'main' into metadata-lora
linoytsaban May 18, 2025
5d578c9
Merge branch 'main' into metadata-lora
sayakpaul May 19, 2025
1c37845
Merge branch 'main' into metadata-lora
linoytsaban May 20, 2025
2bf7fde
Merge branch 'main' into metadata-lora
sayakpaul May 21, 2025
4304a6d
change to a different approach.
sayakpaul May 22, 2025
425ea95
fix-copies.
sayakpaul May 22, 2025
e08830e
todo
sayakpaul May 22, 2025
40f5c97
sd3
sayakpaul May 22, 2025
5a2a023
update
sayakpaul May 22, 2025
0ae3408
revert changes in get_peft_kwargs.
sayakpaul May 22, 2025
99fe09c
update
sayakpaul May 22, 2025
f4d4179
fixes
sayakpaul May 22, 2025
46f4726
fixes
sayakpaul May 22, 2025
c4bd1c7
Merge branch 'main' into metadata-lora
sayakpaul May 22, 2025
1348463
simplify _load_sft_state_dict_metadata
sayakpaul May 22, 2025
ef16bce
update
sayakpaul May 22, 2025
9cba78e
style fix
sayakpaul May 22, 2025
28d634f
uipdate
sayakpaul May 22, 2025
e07ace0
update
sayakpaul May 22, 2025
c762b7c
update
sayakpaul May 22, 2025
c8c33d3
empty commit
sayakpaul May 22, 2025
aabfb5f
resolve conflicts.
sayakpaul May 22, 2025
a4f78c8
Merge branch 'main' into metadata-lora
sayakpaul May 27, 2025
e3e8b20
Merge branch 'main' into metadata-lora
sayakpaul Jun 2, 2025
72b489d
resolve conflicts.
sayakpaul Jun 4, 2025
d952267
_pack_dict_with_prefix
sayakpaul Jun 5, 2025
9bbc6dc
update
sayakpaul Jun 5, 2025
eb52469
TODO 1.
sayakpaul Jun 5, 2025
461d2bd
todo: 2.
sayakpaul Jun 5, 2025
f78c6f9
todo: 3.
sayakpaul Jun 5, 2025
0eba7e7
update
sayakpaul Jun 5, 2025
a4a15b5
update
sayakpaul Jun 5, 2025
4588d83
Merge branch 'main' into metadata-lora
sayakpaul Jun 5, 2025
47cad58
Merge branch 'main' into metadata-lora
sayakpaul Jun 6, 2025
252fd21
Apply suggestions from code review
sayakpaul Jun 6, 2025
29ff6f1
reraise.
sayakpaul Jun 6, 2025
0007969
Merge branch 'main' into metadata-lora
sayakpaul Jun 6, 2025
cbc01a3
Merge branch 'main' into metadata-lora
sayakpaul Jun 9, 2025
2cb9e46
Merge branch 'main' into metadata-lora
sayakpaul Jun 10, 2025
c0d5156
Merge branch 'main' into metadata-lora
sayakpaul Jun 12, 2025
603462e
Merge branch 'main' into metadata-lora
sayakpaul Jun 13, 2025
37a225a
move argument.
sayakpaul Jun 13, 2025
1c03709
Merge branch 'main' into metadata-lora
sayakpaul Jun 13, 2025
5 changes: 1 addition & 4 deletions examples/community/ip_adapter_face_id.py
@@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
15 changes: 3 additions & 12 deletions src/diffusers/loaders/ip_adapter.py
@@ -159,10 +159,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
52 changes: 46 additions & 6 deletions src/diffusers/loaders/lora_base.py
@@ -14,6 +14,7 @@

import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


if is_transformers_available():
@@ -62,6 +64,7 @@

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
metadata=None,
Review comment (Member): Why is it passed here?

Reply (sayakpaul, Member Author): Because we pass `metadata` while loading the adapter, which is, in turn, passed to this method if you step through it.

):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
metadata = _load_sft_state_dict_metadata(model_file)

except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
metadata = None
pass

if model_file is None:
@@ -261,10 +268,11 @@
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
metadata = None
else:
state_dict = pretrained_model_name_or_path_or_dict

return state_dict
return state_dict, metadata
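With this change `_fetch_state_dict` returns a `(state_dict, metadata)` pair, with the metadata parsed out of the safetensors file header. A safetensors header can only hold string→string pairs, so the adapter config travels as a JSON blob under `LORA_ADAPTER_METADATA_KEY`. A minimal stdlib sketch of that round-trip (function names here are illustrative — the real code reads the header via `safetensors` in `_load_sft_state_dict_metadata`):

```python
import json

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

def serialize_header(lora_config):
    # What ends up in the safetensors header: str -> str only,
    # so the config dict is flattened into one JSON string.
    header = {"format": "pt"}
    header[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_config, indent=2, sort_keys=True)
    return header

def parse_header(header):
    # Mirrors the load path: return None when no LoRA metadata is present.
    blob = header.get(LORA_ADAPTER_METADATA_KEY)
    return json.loads(blob) if blob is not None else None

cfg = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k"]}
assert parse_header(serialize_header(cfg)) == cfg   # lossless round-trip
assert parse_header({"format": "pt"}) is None       # plain checkpoints yield no metadata
```

This is why the non-safetensors branches above set `metadata = None`: pickle checkpoints have no header to carry the blob.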


def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
return weight_name


def _pack_dict_with_prefix(state_dict, prefix):
sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
return sd_with_prefix
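The new `_pack_dict_with_prefix` helper namespaces every key; at load time the inverse operation (visible in `_load_lora_into_text_encoder` below) keeps only the keys under a prefix and strips it back off. A standalone sketch of the pair (names hypothetical; `str.removeprefix` needs Python 3.9+):

```python
def pack_dict_with_prefix(state_dict, prefix):
    # Mirrors _pack_dict_with_prefix: namespace every key under `prefix`.
    return {f"{prefix}.{k}": v for k, v in state_dict.items()}

def unpack_prefixed(state_dict, prefix):
    # Inverse used at load time: filter to `prefix` keys and strip the prefix.
    return {
        k.removeprefix(f"{prefix}."): v
        for k, v in state_dict.items()
        if k.startswith(f"{prefix}.")
    }

packed = pack_dict_with_prefix({"lora_A.weight": 1, "lora_B.weight": 2}, "text_encoder")
assert packed == {"text_encoder.lora_A.weight": 1, "text_encoder.lora_B.weight": 2}
assert unpack_prefixed(packed, "text_encoder") == {"lora_A.weight": 1, "lora_B.weight": 2}
```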


def _load_lora_into_text_encoder(
state_dict,
network_alphas,
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

if network_alphas and metadata:
raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")

peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
Review comment (sayakpaul, Member Author): For future: _load_lora_into_text_encoder() and load_lora_adapter() share a bunch of common logic which I will refactor in a future PR.

metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
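Wrapping the `LoraConfig(**lora_config_kwargs)` call means a stale or hand-edited metadata blob fails with a pointed message while chaining the original error for debugging. The pattern in isolation, with a stand-in factory (`fake_lora_config` is hypothetical, not the real PEFT class):

```python
def build_config(factory, kwargs):
    # Unexpected keys in persisted metadata surface here as a TypeError;
    # re-raise with context, chaining the original exception.
    try:
        return factory(**kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e

def fake_lora_config(r=4, lora_alpha=4):
    return {"r": r, "lora_alpha": lora_alpha}

assert build_config(fake_lora_config, {"r": 8}) == {"r": 8, "lora_alpha": 4}
try:
    build_config(fake_lora_config, {"bogus_key": 1})
except TypeError as e:
    assert e.__cause__ is not None  # original TypeError is chained via `from e`
```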

# adapter_name
if adapter_name is None:
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
return _pack_dict_with_prefix(layers_weights, prefix)

@staticmethod
def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
weight_name: str,
save_function: Callable,
safe_serialization: bool,
lora_adapter_metadata: Optional[dict] = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if lora_adapter_metadata and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
raise TypeError("`lora_adapter_metadata` must be of type `dict`.")

if save_function is None:
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
lora_adapter_metadata, indent=2, sort_keys=True
)

return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save
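The set→list conversion in the save path above exists because `LoraConfig` fields such as `target_modules` may be a `set`, which `json.dumps` cannot serialize. A small stdlib sketch of the sanitization step (using `sorted` rather than the PR's `list(value)` so the header is deterministic — an illustrative variation, not the shipped code):

```python
import json

def to_jsonable(metadata):
    # Sets (e.g. LoraConfig.target_modules) are not JSON-serializable;
    # convert them to sorted lists before dumping into the header.
    return {k: sorted(v) if isinstance(v, set) else v for k, v in metadata.items()}

meta = {"r": 8, "target_modules": {"to_q", "to_v"}}
try:
    json.dumps(meta)          # fails: set is not JSON-serializable
except TypeError:
    pass
blob = json.dumps(to_jsonable(meta), sort_keys=True)
assert json.loads(blob)["target_modules"] == ["to_q", "to_v"]
```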