From a7b915714f9a29db8356ede63bfb31c31e575ce4 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 May 2025 10:00:22 +0530
Subject: [PATCH 1/3] fix peft delete adapters for flux.

---
 src/diffusers/loaders/lora_pipeline.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 10b6a8f02710..6092eeff0a80 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2545,14 +2545,13 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
 
             base_param_name = (
                 f"{k.replace(prefix, '')}.base_layer.weight"
-                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
                 else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]

From 1946baa70c42b51f9a069062ce78b5453074e8e1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 21 May 2025 10:36:32 +0530
Subject: [PATCH 2/3] add test

---
 tests/lora/utils.py | 48 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 87a8fddfa583..cc760ea84cd0 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -2149,3 +2149,51 @@ def check_module(denoiser):
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+    def test_inference_load_delete_load_adapters(self):
+        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+                # First, delete adapter and compare.
+                pipe.delete_adapters(pipe.get_active_adapters()[0])
+                output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
+
+                # Then load adapter and compare.
+                pipe.load_lora_weights(tmpdirname)
+                output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))

From e7ad7b0274f54c072aa73089999b410aeb3c0036 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 22 May 2025 14:56:32 +0530
Subject: [PATCH 3/3] empty commit
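
For context, a minimal sketch of the user-facing flow this series repairs: load a LoRA
into a Flux pipeline, delete it, then load one again. The LoRA repo path and adapter
names below are placeholders, not part of the patch; substitute any Flux-compatible
LoRA checkpoint.

    # Sketch of the load -> delete -> load flow exercised by the new test,
    # using placeholder LoRA paths and adapter names.
    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # First load: wraps the targeted modules in PEFT LoRA layers, so their
    # weights move under `.base_layer.weight` names in the state dict.
    pipe.load_lora_weights("path/to/flux-lora", adapter_name="first")

    # Delete the active adapter. This can drop `transformer.peft_config`
    # while modules remain PEFT-wrapped, which is the state the old
    # `is_peft_loaded` check in `_maybe_expand_lora_state_dict` mishandled:
    # it fell back to plain `.weight` names and the lookup failed.
    pipe.delete_adapters(pipe.get_active_adapters()[0])

    # Second load: with the fix, the helper probes the state dict for
    # `.base_layer.weight` keys directly, so this now succeeds.
    pipe.load_lora_weights("path/to/flux-lora", adapter_name="second")

    image = pipe("a photo of a corgi").images[0]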