[LoRA] feat: save_lora_adapter() #9862
Changes from all commits
src/diffusers/loaders/peft.py
```diff
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union

+import safetensors
 import torch
 import torch.nn as nn

 from ..utils import (
```
```diff
@@ -189,40 +193,45 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

-        keys = list(state_dict.keys())
-        transformer_keys = [k for k in keys if k.startswith(prefix)]
-        if len(transformer_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
```
Comment on lines +199 to +201: Just a better and more robust way to filter out the state dict.
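For intuition, here is what anchoring the dot into the prefix buys (the key names below are invented for the example, not taken from the PR):

```python
# Invented keys, just to show the filtering difference.
keys = [
    "transformer.blocks.0.attn.to_q.lora_A.weight",   # true child of "transformer"
    "transformer_blocks.0.attn.to_k.lora_A.weight",   # sibling that merely shares the string
]
prefix = "transformer"

# Old filter: the sibling key slips through.
assert [k for k in keys if k.startswith(prefix)] == keys

# New filter: only true children of `prefix` survive.
assert [k for k in keys if k.startswith(f"{prefix}.")] == [keys[0]]
```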
```diff
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
```
Comment on lines +205 to +209: Catching this error earlier than previous.
```diff
-        if len(state_dict.keys()) > 0:
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
                 state_dict = convert_unet_state_dict_to_peft(state_dict)

-            if adapter_name in getattr(self, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )

             rank = {}
             for key, val in state_dict.items():
                 if "lora_B" in key:
                     rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
```
Comment on the `alpha_keys` change: Removing redundant conditions.
```diff
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
                 else:
-                    lora_config_kwargs.pop("use_dora")
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
```
Comment on lines +227 to +234: Breaking the conditionals to be more explicit.
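A plain-Python sketch of the four cases the rewritten block distinguishes (a paraphrase of the diff above, with `peft_is_old` standing in for `is_peft_version("<", "0.9.0")`):

```python
def resolve_use_dora(lora_config_kwargs: dict, peft_is_old: bool) -> dict:
    # Paraphrase of the diff above, not library code.
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if peft_is_old:
                # DoRA requested but the installed peft cannot handle it: fail loudly.
                raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
            # DoRA requested and supported: keep the flag for LoraConfig.
        else:
            if peft_is_old:
                # Flag is False, but old LoraConfig does not accept the kwarg: drop it.
                lora_config_kwargs.pop("use_dora")
            # Flag is False on a new-enough peft: keeping it is harmless.
    return lora_config_kwargs
```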
```diff
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name
```
```diff
@@ -276,6 +285,69 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />

+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
```
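Putting the new method together with the existing loader, a minimal round-trip could look like the sketch below (the checkpoint and LoRA config are placeholders; `UNet2DConditionModel` is just one example of a model inheriting `PeftAdapterMixin`):

```python
from peft import LoraConfig

from diffusers import UNet2DConditionModel  # any PeftAdapterMixin model works

# Placeholder checkpoint; substitute whatever model you are adapting.
model = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
model.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k"]))

# ... fine-tune the adapter ...

# Serializes only the LoRA parameters of the "default" adapter.
model.save_lora_adapter("my-lora")
# -> my-lora/pytorch_lora_weights.safetensors

# Reload onto the model. `prefix=None` because the file was saved from the
# model directly, so its keys carry no pipeline-level ("unet."/"transformer.") prefix.
model.load_lora_adapter("my-lora", prefix=None, use_safetensors=True)
```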
tests/models/test_modeling_common.py
```diff
@@ -44,6 +44,7 @@
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
```
```diff
@@ -65,6 +66,10 @@
 from ..others.test_utils import TOKEN, USER, is_staging_test


+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
```
```diff
@@ -74,6 +79,16 @@ def caculate_expected_num_shards(index_map_path):
     return expected_num_shards


+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
```
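The helper simply scans for any peft `BaseTunerLayer` module; a quick sanity sketch of how it behaves on a toy module (assuming peft's public `inject_adapter_in_model` helper):

```python
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

toy = nn.Sequential(nn.Linear(8, 8))
assert not check_if_lora_correctly_set(toy)  # plain Linear: no tuner layers yet

# Injecting an adapter wraps the targeted Linear in a BaseTunerLayer subclass.
toy = inject_adapter_in_model(LoraConfig(r=4, target_modules=["0"]), toy)
assert check_if_lora_correctly_set(toy)
```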
```diff
@@ -877,8 +892,6 @@ def _set_gradient_checkpointing_new(self, module, value=False):
         model = model_class_copy(**init_dict)
         model.enable_gradient_checkpointing()

-        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
-
```
Comment on the removed `print`: Unrelated but my hands were itching.
```diff
         assert set(modules_with_gc_enabled.keys()) == expected_set
         assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

@@ -902,6 +915,94 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )

+    @parameterized.expand([True, False])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_save_load_lora_adapter(self, use_dora=False):
+        import safetensors
+        from peft import LoraConfig
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k in state_dict_loaded:
+                loaded_v = state_dict_loaded[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_wrong_adapter_name_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```