From bc26350e17b0ab0bfdf3bb740920aa0ed81cb6ae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Dec 2024 09:24:26 +0530 Subject: [PATCH] depcrecate save_attn_procs(). --- src/diffusers/loaders/unet.py | 3 +++ .../unets/test_models_unet_2d_condition.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 201526937b4e..7050968b6de5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -492,6 +492,9 @@ def save_attn_procs( ) state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} else: + deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`." + deprecate("save_attn_procs", "0.40.0", deprecation_message) + if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.") diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 84bc9695fc59..8ec5b6e9a5e4 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1119,6 +1119,24 @@ def test_load_attn_procs_raise_warning(self): lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4 ), "Loading from a saved checkpoint should produce identical results." + @require_peft_backend + def test_save_attn_procs_raise_warning(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + with tempfile.TemporaryDirectory() as tmpdirname: + with self.assertWarns(FutureWarning) as warning: + model.save_attn_procs(tmpdirname) + + warning_message = str(warning.warnings[0].message) + assert "Using the `save_attn_procs()` method has been deprecated" in warning_message + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase):