diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 154aa2d8f9bb..59cbd5a7a960 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -298,8 +298,9 @@ def load_lora_into_unet(
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
@@ -827,8 +828,9 @@ def load_lora_into_unet(
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index cf361e88a670..a1bce35813a5 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union
 
+import safetensors
+import torch
 import torch.nn as nn
 
 from ..utils import (
@@ -189,40 +193,45 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be used when `prefix` is None.")
 
-        keys = list(state_dict.keys())
-        transformer_keys = [k for k in keys if k.startswith(prefix)]
-        if len(transformer_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
 
-        if len(state_dict.keys()) > 0:
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
                 state_dict = convert_unet_state_dict_to_peft(state_dict)
 
-            if adapter_name in getattr(self, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
             rank = {}
             for key, val in state_dict.items():
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
                 else:
-                    lora_config_kwargs.pop("use_dora")
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
 
             lora_config = LoraConfig(**lora_config_kwargs)
             # adapter_name
@@ -276,6 +285,65 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index b37b681ae8fe..201526937b4e 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -36,6 +36,7 @@
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -209,6 +210,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index b711c8c9791e..7cdb2d6f51d7 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -1784,11 +1784,7 @@ def test_missing_keys_warning(self):
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
@@ -1823,11 +1819,7 @@ def test_unexpected_keys_warning(self):
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 7f8dc63e00ac..f6ce6bda7381 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -44,6 +44,7 @@
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
@@ -65,6 +66,10 @@
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +79,16 @@
     return expected_num_shards
 
 
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
@@ -877,8 +892,6 @@ def _set_gradient_checkpointing_new(self, module, value=False):
         model = model_class_copy(**init_dict)
         model.enable_gradient_checkpointing()
 
-        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
-
         assert set(modules_with_gc_enabled.keys()) == expected_set
         assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
 
@@ -902,6 +915,94 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = []`"
         )
 
+    @parameterized.expand([True, False])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_save_load_lora_adapter(self, use_dora=False):
+        import safetensors
+        from peft import LoraConfig
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k in state_dict_loaded:
+                loaded_v = state_dict_loaded[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_wrong_adapter_name_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index fec34822904c..84bc9695fc59 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1078,30 +1078,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert new_output.sample.shape == (4, 4, 16, 16)
 
     @require_peft_backend
-    def test_lora(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        # forward pass without LoRA
-        with torch.no_grad():
-            non_lora_sample = model(**inputs_dict).sample
-
-        unet_lora_config = get_unet_lora_config()
-        model.add_adapter(unet_lora_config)
-
-        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-        # forward pass with LoRA
-        with torch.no_grad():
-            lora_sample = model(**inputs_dict).sample
-
-        assert not torch.allclose(
-            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
-        ), "LoRA injected UNet should produce different results."
-
-    @require_peft_backend
-    def test_lora_serialization(self):
+    def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -1122,8 +1099,14 @@ def test_lora_serialization(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
             model.unload_lora()
-            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            with self.assertWarns(FutureWarning) as warning:
+                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            warning_message = str(warning.warnings[0].message)
+            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
+
+            # important to still check for the rest of the stuff.
             assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
 
             with torch.no_grad():
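
Usage note (illustrative, not part of the patch): the sketch below exercises the new `save_lora_adapter()` / `load_lora_adapter()` round trip that this diff adds to `PeftAdapterMixin`. The tiny UNet configuration, the LoRA rank/target modules, and the output directory are arbitrary example values, and `peft` must be installed for `add_adapter()` to work.

    import torch
    from peft import LoraConfig

    from diffusers import UNet2DConditionModel

    # A deliberately small UNet so the sketch runs quickly; the config values are arbitrary.
    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )

    # Attach a fresh LoRA adapter (rank and target modules are example values).
    unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

    # New in this patch: serialize only the adapter weights
    # (defaults to pytorch_lora_weights.safetensors inside the directory).
    unet.save_lora_adapter("./my-unet-lora")

    # Drop the adapter, then load it back directly on the model. prefix=None because the
    # file was saved from the model itself, so its keys carry no "unet." prefix to strip.
    unet.unload_lora()
    unet.load_lora_adapter("./my-unet-lora", prefix=None, use_safetensors=True)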