[LoRA] parse metadata from LoRA and save metadata #11324
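The diff below (apparently from diffusers' LoRA loader module) threads adapter metadata through both directions: `write_lora_layers` serializes a `lora_adapter_metadata` dict into the safetensors header, and `_fetch_state_dict` parses it back when loading. A minimal round-trip sketch using only the public safetensors API; the file name and tensor contents are illustrative, and the header key matches the `LORA_ADAPTER_METADATA_KEY` constant added below:

import json

import safetensors.torch
import torch
from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # constant added by this PR

# Save: stash the adapter config as a JSON string in the safetensors header.
state_dict = {"unet.lora_A.weight": torch.zeros(4, 16)}
adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k"]}
header = {
    "format": "pt",
    LORA_ADAPTER_METADATA_KEY: json.dumps(adapter_metadata, indent=2, sort_keys=True),
}
safetensors.torch.save_file(state_dict, "pytorch_lora_weights.safetensors", metadata=header)

# Load: read the header back without touching the tensors.
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    raw = f.metadata() or {}
print(json.loads(raw[LORA_ADAPTER_METADATA_KEY])["r"])  # 4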
Changes from all commits
@@ -14,6 +14,7 @@
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
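`_load_sft_state_dict_metadata` itself lives in `state_dict_utils` and is not part of this diff. A plausible sketch of what such a reader does, using the public safetensors API; only the name comes from the import above, the body here is an assumption:

import json

from safetensors import safe_open

def _load_sft_state_dict_metadata(model_file):
    # Assumed behavior: pull the JSON blob stored under the
    # "lora_adapter_metadata" header key; return None when absent.
    with safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    raw = metadata.get("lora_adapter_metadata")
    return json.loads(raw) if raw is not None else None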
@@ -62,6 +64,7 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):

Review comment on the new metadata parameter: "Why is it passed here?"
Author reply: "Because we pass […], which is, in turn, passed to this method if you step through it."
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)

             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass

         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict

-    return state_dict
+    return state_dict, metadata
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name


+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
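The new helper only namespaces keys. A quick self-contained check, with the two-line body repeated verbatim from the hunk above:

def _pack_dict_with_prefix(state_dict, prefix):
    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
    return sd_with_prefix

packed = _pack_dict_with_prefix({"lora_A.weight": 0, "lora_B.weight": 1}, "text_encoder")
print(packed)  # {'text_encoder.lora_A.weight': 0, 'text_encoder.lora_B.weight': 1}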
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
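The metadata dict now gets the same prefix filtering as the state dict. For example, with illustrative keys:

prefix = "text_encoder"
metadata = {"text_encoder.r": 8, "unet.r": 4}
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
print(metadata)  # {'r': 8}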
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
         alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
         network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-    lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

     if "use_dora" in lora_config_kwargs:
         if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")

-    lora_config = LoraConfig(**lora_config_kwargs)
+    try:
+        lora_config = LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e

     # adapter_name
     if adapter_name is None:
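For the `metadata` branch above to work, the parsed dict must be directly usable as `LoraConfig` kwargs. A hedged sketch of the assumed shape; the concrete keys depend on what was serialized at save time:

from peft import LoraConfig

# Assumed shape: plain LoraConfig kwargs. A stale key (e.g. "lora_bias" on
# peft <= 0.13.2, popped above) or an unknown one makes LoraConfig raise
# TypeError, which the new wrapper re-raises with a clearer message.
metadata = {"r": 8, "lora_alpha": 16, "target_modules": ["to_q", "to_v"]}
lora_config = LoraConfig(**metadata)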
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)

     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:

                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)

             else:
                 save_function = torch.save
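A note on the set-to-list conversion above: safetensors header values must be strings, and `json.dumps` rejects sets, so set-valued entries (e.g. `target_modules`) are converted first. A quick demonstration:

import json

try:
    json.dumps({"target_modules": {"to_q", "to_k"}})
except TypeError as err:
    print(err)  # Object of type set is not JSON serializable

print(json.dumps({"target_modules": ["to_q", "to_k"]}))  # serializes fine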