From 5b4548ae256624e4350335b4f0e25ff0d25d5ddb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Nov 2024 17:40:33 +0100 Subject: [PATCH 01/14] feat: save_lora_adapter. --- src/diffusers/loaders/peft.py | 52 +++++++++++++++++- tests/models/test_modeling_common.py | 82 ++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index cf361e88a670..8f61f91f19f2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -13,9 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os from functools import partial +from pathlib import Path from typing import Dict, List, Optional, Union +import safetensors +import torch import torch.nn as nn from ..utils import ( @@ -203,7 +207,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if adapter_name in getattr(self, "peft_config", {}): raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) rank = {} @@ -276,6 +280,52 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans _pipeline.enable_sequential_cpu_offload() # Unsafe code /> + def save_lora_adapter( + self, + save_directory, + adapter_name: str = "default", + upcast_before_saving: bool = False, + safe_serialization: bool = True, + weight_name: str = None, + ): + """TODO""" + from peft.utils import get_peft_model_state_dict + + from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if adapter_name is None: + adapter_name = get_adapter_name(self) + + if adapter_name not in getattr(self, "peft_config", {}): + raise ValueError(f"Adapter name {adapter_name} not found in the model.") + + lora_layers_to_save = get_peft_model_state_dict( + self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name + ) + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_path = Path(save_directory, weight_name).as_posix() + save_function(lora_layers_to_save, save_path) + logger.info(f"Model weights saved in {save_path}") + def set_adapters( self, adapter_names: Union[List[str], str], diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 7f8dc63e00ac..c4ae546ac1ae 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -44,6 +44,7 @@ from diffusers.utils import ( SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME, + is_peft_available, is_torch_npu_available, is_xformers_available, logging, @@ -53,6 +54,7 @@ CaptureLogger, get_python_version, is_torch_compile, + require_peft_backend, require_torch_2, require_torch_accelerator_with_training, require_torch_gpu, @@ -65,6 +67,13 @@ from ..others.test_utils import TOKEN, USER, is_staging_test +if is_peft_available(): + from peft import LoraConfig + from peft.tuners.tuners_utils import 
BaseTunerLayer + + from diffusers.loaders import PeftAdapterMixin + + def caculate_expected_num_shards(index_map_path): with open(index_map_path) as f: weight_map_dict = json.load(f)["weight_map"] @@ -74,6 +83,16 @@ def caculate_expected_num_shards(index_map_path): return expected_num_shards +def check_if_lora_correctly_set(model) -> bool: + """ + Checks if the LoRA layers are correctly set with peft + """ + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return True + return False + + # Will be run via run_test_in_subprocess def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): error = None @@ -902,6 +921,69 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) + @require_peft_backend + @parameterized.expand([True, False]) + def test_load_save_lora_adapter(self, use_dora=False): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + torch.manual_seed(0) + output_no_lora = model(**inputs_dict).sample + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + torch.manual_seed(0) + outputs_with_lora = model(**inputs_dict).sample + + self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model.unload_lora() + model.load_lora_adapter(tmpdir, use_safetensors=True) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + torch.manual_seed(0) + outputs_with_lora_2 = model(**inputs_dict).sample + + self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) + self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) + + def test_wrong_adapter_name_raises_error(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with tempfile.TemporaryDirectory() as tmpdir: + wrong_name = "foo" + with self.assertRaises(ValueError) as err_context: + model.save_lora_adapter(tmpdir, adapter_name=wrong_name) + + self.assertTrue(f"Adapter name {wrong_name} not found in the model." 
in str(err_context.exception)) + @require_torch_gpu def test_cpu_offload(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 9fd157e73e76c0f76152b914d528c96e3e61817b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Nov 2024 18:08:46 +0100 Subject: [PATCH 02/14] add tests --- src/diffusers/loaders/lora_pipeline.py | 6 ++-- src/diffusers/loaders/peft.py | 10 +++--- src/diffusers/loaders/unet.py | 5 +++ .../unets/test_models_unet_2d_condition.py | 33 +++++-------------- 4 files changed, 22 insertions(+), 32 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 154aa2d8f9bb..59cbd5a7a960 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -298,8 +298,9 @@ def load_lora_into_unet( if not only_text_encoder: # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") - unet.load_attn_procs( + unet.load_lora_adapter( state_dict, + prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, @@ -827,8 +828,9 @@ def load_lora_into_unet( if not only_text_encoder: # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") - unet.load_attn_procs( + unet.load_lora_adapter( state_dict, + prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 8f61f91f19f2..31e1fe47a5fa 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -200,16 +200,16 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys} if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - if adapter_name in getattr(self, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + rank = {} for key, val in state_dict.items(): if "lora_B" in key: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 2fa7732a6a3b..9a692d324be9 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -36,6 +36,7 @@ USE_PEFT_BACKEND, _get_model_file, convert_unet_state_dict_to_peft, + deprecate, get_adapter_name, get_peft_kwargs, is_accelerate_available, @@ -209,6 +210,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict is_model_cpu_offload = False is_sequential_cpu_offload = False + if is_lora: + deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`." 
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index fec34822904c..84bc9695fc59 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1078,30 +1078,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert new_output.sample.shape == (4, 4, 16, 16)
 
     @require_peft_backend
-    def test_lora(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        # forward pass without LoRA
-        with torch.no_grad():
-            non_lora_sample = model(**inputs_dict).sample
-
-        unet_lora_config = get_unet_lora_config()
-        model.add_adapter(unet_lora_config)
-
-        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-        # forward pass with LoRA
-        with torch.no_grad():
-            lora_sample = model(**inputs_dict).sample
-
-        assert not torch.allclose(
-            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
-        ), "LoRA injected UNet should produce different results."
-
-    @require_peft_backend
-    def test_lora_serialization(self):
+    def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -1122,8 +1099,14 @@ def test_lora_serialization(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
             model.unload_lora()
-            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
 
+            with self.assertWarns(FutureWarning) as warning:
+                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            warning_message = str(warning.warnings[0].message)
+            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
+
+            # important to still check that the LoRA layers were loaded correctly.
             assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad(): From 5694f112e78e125ea60b226b505bfcd8596b59a9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 09:54:50 +0100 Subject: [PATCH 03/14] decorate --- tests/models/test_modeling_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index c4ae546ac1ae..a058c3c60e78 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -960,6 +960,7 @@ def test_load_save_lora_adapter(self, use_dora=False): self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) + @require_peft_backend def test_wrong_adapter_name_raises_error(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) From e70f26511538022713f7eff458399f55cc73ca48 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 09:59:50 +0100 Subject: [PATCH 04/14] fixes --- tests/lora/utils.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b711c8c9791e..288cdc48c901 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1784,11 +1784,7 @@ def test_missing_keys_warning(self): missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] - logger = ( - logging.get_logger("diffusers.loaders.unet") - if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.peft") - ) + logger = logging.get_logger("diffusers.loaders.peft") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) @@ -1796,6 +1792,7 @@ def test_missing_keys_warning(self): # Since the missing key won't contain the adapter name ("default_0"). # Also strip out the component prefix (such as "unet." from `missing_key`). 
component = list({k.split(".")[0] for k in state_dict})[0] + print(f"{cap_logger.out=}") self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")) def test_unexpected_keys_warning(self): @@ -1823,11 +1820,7 @@ def test_unexpected_keys_warning(self): unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) - logger = ( - logging.get_logger("diffusers.loaders.unet") - if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.peft") - ) + logger = logging.get_logger("diffusers.loaders.peft") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) From 4cb583c4766b9fe368647620daf5a12f7e2c70fb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 11:19:32 +0100 Subject: [PATCH 05/14] fixes --- src/diffusers/loaders/peft.py | 19 +++++++++++-------- tests/models/test_modeling_common.py | 3 ++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 31e1fe47a5fa..c622c91f0875 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -195,11 +195,12 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans ) keys = list(state_dict.keys()) - transformer_keys = [k for k in keys if k.startswith(prefix)] - if len(transformer_keys) > 0: - state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys} + model_keys = [k for k in keys if k.startswith(prefix)] + if len(model_keys) > 0: + state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} if len(state_dict.keys()) > 0: + print("Within actual.") if adapter_name in getattr(self, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." @@ -221,12 +222,14 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
+ ) else: - lora_config_kwargs.pop("use_dora") + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a058c3c60e78..4a1faf940fe3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -922,8 +922,9 @@ def test_deprecated_kwargs(self): ) @require_peft_backend + @torch.no_grad() @parameterized.expand([True, False]) - def test_load_save_lora_adapter(self, use_dora=False): + def test_save_load_lora_adapter(self, use_dora=False): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) From 1d27595790700d756212f1972bbab8f2b2c3a3e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 11:20:28 +0100 Subject: [PATCH 06/14] fixes --- tests/models/test_modeling_common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4a1faf940fe3..c8fc2ca2964d 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -71,8 +71,6 @@ from peft import LoraConfig from peft.tuners.tuners_utils import BaseTunerLayer - from diffusers.loaders import PeftAdapterMixin - def caculate_expected_num_shards(index_map_path): with open(index_map_path) as f: @@ -925,6 +923,8 @@ def test_deprecated_kwargs(self): @torch.no_grad() @parameterized.expand([True, False]) def test_save_load_lora_adapter(self, use_dora=False): + from diffusers.loaders.peft import PeftAdapterMixin + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) @@ -963,6 +963,8 @@ def test_save_load_lora_adapter(self, use_dora=False): @require_peft_backend def test_wrong_adapter_name_raises_error(self): + from diffusers.loaders.peft import PeftAdapterMixin + init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) From 63916fad5920d07e3ce7d760707798e726dfe07f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 11:50:42 +0100 Subject: [PATCH 07/14] fixes --- tests/models/test_modeling_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index c8fc2ca2964d..21df4b4e1712 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -68,7 +68,6 @@ if is_peft_available(): - from peft import LoraConfig from peft.tuners.tuners_utils import BaseTunerLayer @@ -923,6 +922,8 @@ def test_deprecated_kwargs(self): @torch.no_grad() @parameterized.expand([True, False]) def test_save_load_lora_adapter(self, use_dora=False): + from peft import LoraConfig + from diffusers.loaders.peft import PeftAdapterMixin init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -963,6 +964,8 @@ def test_save_load_lora_adapter(self, use_dora=False): @require_peft_backend def test_wrong_adapter_name_raises_error(self): + from peft import LoraConfig + from diffusers.loaders.peft import PeftAdapterMixin init_dict, _ = self.prepare_init_args_and_inputs_for_common() From 1ed05a4c3ef3a372cebd610baea72e3c37f28f80 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 12:04:33 +0100 Subject: [PATCH 08/14] fixes --- check_flux.py | 0 tests/models/test_modeling_common.py | 5 
++--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 check_flux.py diff --git a/check_flux.py b/check_flux.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 21df4b4e1712..929fbd38dddc 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -54,7 +54,6 @@ CaptureLogger, get_python_version, is_torch_compile, - require_peft_backend, require_torch_2, require_torch_accelerator_with_training, require_torch_gpu, @@ -918,9 +917,9 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) - @require_peft_backend @torch.no_grad() @parameterized.expand([True, False]) + @unittest.skipIf(not is_peft_available(), "Only with PEFT") def test_save_load_lora_adapter(self, use_dora=False): from peft import LoraConfig @@ -962,7 +961,7 @@ def test_save_load_lora_adapter(self, use_dora=False): self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) - @require_peft_backend + @unittest.skipIf(not is_peft_available(), "Only with PEFT") def test_wrong_adapter_name_raises_error(self): from peft import LoraConfig From 8e9f683e72aa2a605a784d9f3d3773cec065bc3b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 13:55:09 +0100 Subject: [PATCH 09/14] fixes --- check_flux.py | 98 ++++++++++++++++++++++++++++ tests/models/test_modeling_common.py | 8 +-- 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/check_flux.py b/check_flux.py index e69de29bb2d1..6cfb426ce436 100644 --- a/check_flux.py +++ b/check_flux.py @@ -0,0 +1,98 @@ +from diffusers import AutoencoderKL, FluxTransformer2DModel, FluxPipeline, FlowMatchEulerDiscreteScheduler +from transformers import CLIPTokenizer, AutoTokenizer, CLIPTextModel, T5EncoderModel +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +import numpy as np +import tempfile +import os +import torch + +def _get_lora_state_dicts(modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts + +transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], +} +transformer = FluxTransformer2DModel(**transformer_kwargs) + +vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, +} +vae = AutoencoderKL(**vae_kwargs) + +tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") +tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") +text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") +text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + +pipeline = FluxPipeline( + transformer=transformer, + scheduler=FlowMatchEulerDiscreteScheduler(), + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + tokenizer_2=tokenizer_2, + text_encoder_2=text_encoder_2 
+) + +pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", +} + +output_no_lora = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] + +denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, +) +transformer.add_adapter(denoiser_lora_config) +print(pipeline.transformer.peft_config) + +output_lora = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] + +assert not np.allclose(output_lora, output_no_lora, atol=1e-4, rtol=1e-4) + +with tempfile.TemporaryDirectory() as tmpdir: + lora_state_dicts = _get_lora_state_dicts({"transformer": pipeline.transformer}) + FluxPipeline.save_lora_weights( + save_directory=tmpdir, safe_serialization=True, **lora_state_dicts + ) + + assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + pipeline.unload_lora_weights() + pipeline.load_lora_weights(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), low_cpu_mem_usage=True) + + images_lora_from_pretrained = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] + + assert not np.allclose(images_lora_from_pretrained, output_no_lora, atol=1e-4, rtol=1e-4) + assert np.allclose(output_lora, output_no_lora, atol=1e-4, rtol=1e-4) \ No newline at end of file diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 929fbd38dddc..160f529f20ff 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -917,8 +917,8 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) - @torch.no_grad() @parameterized.expand([True, False]) + @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") def test_save_load_lora_adapter(self, use_dora=False): from peft import LoraConfig @@ -932,7 +932,7 @@ def test_save_load_lora_adapter(self, use_dora=False): return torch.manual_seed(0) - output_no_lora = model(**inputs_dict).sample + output_no_lora = model(**inputs_dict, return_dict=False)[0] denoiser_lora_config = LoraConfig( r=4, @@ -945,7 +945,7 @@ def test_save_load_lora_adapter(self, use_dora=False): self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") torch.manual_seed(0) - outputs_with_lora = model(**inputs_dict).sample + outputs_with_lora = model(**inputs_dict, return_dict=False)[0] self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) @@ -956,7 +956,7 @@ def test_save_load_lora_adapter(self, use_dora=False): self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") torch.manual_seed(0) - outputs_with_lora_2 = model(**inputs_dict).sample + outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) From 45bf5b318b9c956059abcd2185bc9615c473278d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 15:09:22 +0100 Subject: [PATCH 10/14] fixes --- check_flux.py | 4 +++- src/diffusers/loaders/peft.py | 3 +-- tests/models/test_modeling_common.py | 20 ++++++++++++++++++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/check_flux.py b/check_flux.py index 6cfb426ce436..659c6514868a 100644 --- a/check_flux.py +++ b/check_flux.py @@ -90,9 
+90,11 @@ def _get_lora_state_dicts(modules_to_save): assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) pipeline.unload_lora_weights() + assert not hasattr(pipeline.transformer, "peft_config") + pipeline.load_lora_weights(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), low_cpu_mem_usage=True) images_lora_from_pretrained = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] assert not np.allclose(images_lora_from_pretrained, output_no_lora, atol=1e-4, rtol=1e-4) - assert np.allclose(output_lora, output_no_lora, atol=1e-4, rtol=1e-4) \ No newline at end of file + assert np.allclose(output_lora, images_lora_from_pretrained, atol=1e-4, rtol=1e-4) \ No newline at end of file diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index c622c91f0875..9cd6227d28cc 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -199,8 +199,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if len(model_keys) > 0: state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} - if len(state_dict.keys()) > 0: - print("Within actual.") + if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 160f529f20ff..fedb400a2e72 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -892,8 +892,6 @@ def _set_gradient_checkpointing_new(self, module, value=False): model = model_class_copy(**init_dict) model.enable_gradient_checkpointing() - print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}") - assert set(modules_with_gc_enabled.keys()) == expected_set assert all(modules_with_gc_enabled.values()), "All modules should be enabled" @@ -921,7 +919,9 @@ def test_deprecated_kwargs(self): @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") def test_save_load_lora_adapter(self, use_dora=False): + import safetensors from peft import LoraConfig + from peft.utils import get_peft_model_state_dict from diffusers.loaders.peft import PeftAdapterMixin @@ -951,8 +951,24 @@ def test_save_load_lora_adapter(self, use_dora=False): with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + model.unload_lora() + self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + model.load_lora_adapter(tmpdir, use_safetensors=True) + state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") + + print(f"{state_dict_loaded.keys()=}, {len(state_dict_loaded)}") + print(f"{state_dict_retrieved.keys()=} {len(state_dict_retrieved)}") + + for k in state_dict_loaded: + loaded_v = state_dict_loaded[k] + retrieved_v = state_dict_retrieved[k] + self.assertTrue(torch.allclose(loaded_v, retrieved_v)) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") torch.manual_seed(0) From 01e87ccfece6aad7c8153ae2d10d3d4780ffaab8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 16:47:12 +0100 Subject: [PATCH 11/14] fixes, thanks to @ BenjaminBossan. 
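
A minimal sketch of the save -> unload -> load round-trip this series enables,
distilled from `test_save_load_lora_adapter` (illustrative only: `model` stands
for any model inheriting `PeftAdapterMixin`, `tmpdir` for a scratch directory,
and the tiny `LoraConfig` values mirror the test):

    from peft import LoraConfig

    config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
    )
    model.add_adapter(config)
    # Writes `pytorch_lora_weights.safetensors` into the directory by default.
    model.save_lora_adapter(tmpdir)
    model.unload_lora()
    # `prefix=None`, since `save_lora_adapter` writes keys without a
    # "transformer."/"unet." prefix.
    model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)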
--- check_flux.py | 100 --------------------------- src/diffusers/loaders/peft.py | 15 ++-- tests/models/test_modeling_common.py | 7 +- 3 files changed, 12 insertions(+), 110 deletions(-) delete mode 100644 check_flux.py diff --git a/check_flux.py b/check_flux.py deleted file mode 100644 index 659c6514868a..000000000000 --- a/check_flux.py +++ /dev/null @@ -1,100 +0,0 @@ -from diffusers import AutoencoderKL, FluxTransformer2DModel, FluxPipeline, FlowMatchEulerDiscreteScheduler -from transformers import CLIPTokenizer, AutoTokenizer, CLIPTextModel, T5EncoderModel -from peft import LoraConfig -from peft.utils import get_peft_model_state_dict -import numpy as np -import tempfile -import os -import torch - -def _get_lora_state_dicts(modules_to_save): - state_dicts = {} - for module_name, module in modules_to_save.items(): - if module is not None: - state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) - return state_dicts - -transformer_kwargs = { - "patch_size": 1, - "in_channels": 4, - "num_layers": 1, - "num_single_layers": 1, - "attention_head_dim": 16, - "num_attention_heads": 2, - "joint_attention_dim": 32, - "pooled_projection_dim": 32, - "axes_dims_rope": [4, 4, 8], -} -transformer = FluxTransformer2DModel(**transformer_kwargs) - -vae_kwargs = { - "sample_size": 32, - "in_channels": 3, - "out_channels": 3, - "block_out_channels": (4,), - "layers_per_block": 1, - "latent_channels": 1, - "norm_num_groups": 1, - "use_quant_conv": False, - "use_post_quant_conv": False, - "shift_factor": 0.0609, - "scaling_factor": 1.5035, -} -vae = AutoencoderKL(**vae_kwargs) - -tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") -tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") -text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") -text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - -pipeline = FluxPipeline( - transformer=transformer, - scheduler=FlowMatchEulerDiscreteScheduler(), - vae=vae, - tokenizer=tokenizer, - text_encoder=text_encoder, - tokenizer_2=tokenizer_2, - text_encoder_2=text_encoder_2 -) - -pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 4, - "guidance_scale": 0.0, - "height": 8, - "width": 8, - "output_type": "np", -} - -output_no_lora = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] - -denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=False, -) -transformer.add_adapter(denoiser_lora_config) -print(pipeline.transformer.peft_config) - -output_lora = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] - -assert not np.allclose(output_lora, output_no_lora, atol=1e-4, rtol=1e-4) - -with tempfile.TemporaryDirectory() as tmpdir: - lora_state_dicts = _get_lora_state_dicts({"transformer": pipeline.transformer}) - FluxPipeline.save_lora_weights( - save_directory=tmpdir, safe_serialization=True, **lora_state_dicts - ) - - assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - pipeline.unload_lora_weights() - assert not hasattr(pipeline.transformer, "peft_config") - - pipeline.load_lora_weights(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), low_cpu_mem_usage=True) - - images_lora_from_pretrained = pipeline(**pipeline_inputs, generator=torch.manual_seed(0))[0] - - assert not np.allclose(images_lora_from_pretrained, output_no_lora, 
atol=1e-4, rtol=1e-4)
-    assert np.allclose(output_lora, images_lora_from_pretrained, atol=1e-4, rtol=1e-4)
\ No newline at end of file
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 9cd6227d28cc..2e24a5ab69ff 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -193,11 +193,14 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be used when `prefix` is None.")
 
-        keys = list(state_dict.keys())
-        model_keys = [k for k in keys if k.startswith(prefix)]
-        if len(model_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}):
@@ -216,7 +219,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [
+                    k for k in network_alphas.keys() if k.startswith(f"{prefix}.") and k.split(".")[0] == prefix
+                ]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index fedb400a2e72..f6ce6bda7381 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -958,15 +958,12 @@ def test_save_load_lora_adapter(self, use_dora=False):
             model.unload_lora()
             self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
 
-            model.load_lora_adapter(tmpdir, use_safetensors=True)
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
             state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
 
-            print(f"{state_dict_loaded.keys()=}, {len(state_dict_loaded)}")
-            print(f"{state_dict_retrieved.keys()=} {len(state_dict_retrieved)}")
-
             for k in state_dict_loaded:
                 loaded_v = state_dict_loaded[k]
-                retrieved_v = state_dict_retrieved[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
                 self.assertTrue(torch.allclose(loaded_v, retrieved_v))
 
             self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

From 628ad098ebd982a5d7f0bb55a06b631337b76899 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 5 Nov 2024 17:45:54 +0100
Subject: [PATCH 12/14] fixes

---
 src/diffusers/loaders/peft.py | 6 +++---
 tests/lora/utils.py           | 1 -
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 2e24a5ab69ff..16bb711c54ed 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -293,7 +293,7 @@ def save_lora_adapter(
         adapter_name: str = "default",
         upcast_before_saving: bool = False,
         safe_serialization: bool = True,
-        weight_name: str = None,
+        weight_name: Optional[str] = None,
     ):
         """TODO"""
         from peft.utils import get_peft_model_state_dict
@@ -310,8 +310,7 @@ 
def save_lora_adapter( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name ) if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") if safe_serialization: @@ -329,6 +328,7 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME + # TODO: we could consider saving the `peft_config` as well. save_path = Path(save_directory, weight_name).as_posix() save_function(lora_layers_to_save, save_path) logger.info(f"Model weights saved in {save_path}") diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 288cdc48c901..7cdb2d6f51d7 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1792,7 +1792,6 @@ def test_missing_keys_warning(self): # Since the missing key won't contain the adapter name ("default_0"). # Also strip out the component prefix (such as "unet." from `missing_key`). component = list({k.split(".")[0] for k in state_dict})[0] - print(f"{cap_logger.out=}") self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")) def test_unexpected_keys_warning(self): From 1f3c9ff7d7afe5ef5835f626ea3e1bfac1b4ba4a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Nov 2024 17:56:40 +0100 Subject: [PATCH 13/14] remove redundant conditions. --- src/diffusers/loaders/peft.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 16bb711c54ed..1b0a38029dff 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -219,9 +219,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans rank[key] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [ - k for k in network_alphas.keys() if k.startswith(f"{prefix}.") and k.split(".")[0] == prefix - ] + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) From 7f3d5e30b29b9b0d36889817eae5caead898d5c0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 9 Nov 2024 11:35:11 -0400 Subject: [PATCH 14/14] add documentation for save_lora_adapter(). --- src/diffusers/loaders/peft.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 1b0a38029dff..a1bce35813a5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -293,7 +293,24 @@ def save_lora_adapter( safe_serialization: bool = True, weight_name: Optional[str] = None, ): - """TODO""" + """ + Save the LoRA parameters corresponding to the underlying model. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the + underlying model has multiple adapters loaded. + upcast_before_saving (`bool`, defaults to `False`): + Whether to cast the underlying model to `torch.float32` before serialization. + save_function (`Callable`): + The function to use to save the state dictionary. 
Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. + """ from peft.utils import get_peft_model_state_dict from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE