diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
index 1d253f9afad9..07b1cda2f79f 100644
--- a/tests/lora/test_lora_layers_lumina2.py
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -15,6 +15,8 @@
 import sys
 import unittest
 
+import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, GemmaForCausalLM
 
@@ -24,12 +26,12 @@
     Lumina2Text2ImgPipeline,
     Lumina2Transformer2DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
 
 
 sys.path.append(".")
 
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 
 
 @require_peft_backend
@@ -130,3 +132,41 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @skip_mps
+    @pytest.mark.xfail(
+        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+        strict=False,
+    )
+    def test_lora_fuse_nan(self):
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            # corrupt one LoRA weight with `inf` values
+            with torch.no_grad():
+                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+            # with `safe_fusing=True` we should see an error
+            with self.assertRaises(ValueError):
+                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+            # without `safe_fusing` we should not see an error, but every image will be black
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+            out = pipe(**inputs)[0]
+
+            self.assertTrue(np.isnan(out).all())
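
For context, a minimal, self-contained sketch (not part of the patch, and not the actual diffusers/PEFT API) of the behavior test_lora_fuse_nan exercises: writing inf into one LoRA factor makes the merged weight W + B @ A non-finite, which a safe_fusing-style guard can reject up front, while an unguarded fuse silently propagates NaN/Inf into every output. The fuse() helper below is hypothetical and stands in for pipe.fuse_lora().

import torch

# Hypothetical stand-in for the real fuse_lora(); illustration only.
def fuse(W, A, B, safe_fusing=False):
    fused = W + B @ A  # standard LoRA merge of a low-rank update into the base weight
    if safe_fusing and not torch.isfinite(fused).all():
        raise ValueError("Fused weights contain inf/NaN; refusing to fuse.")
    return fused

W = torch.randn(4, 4)   # base weight
A = torch.randn(2, 4)   # LoRA down-projection
B = torch.randn(4, 2)   # LoRA up-projection
A[0, 0] = float("inf")  # corrupt one factor, as the test does

try:
    fuse(W, A, B, safe_fusing=True)   # raises, like safe_fusing=True in the test
except ValueError as err:
    print(err)

fused = fuse(W, A, B, safe_fusing=False)  # fuses without complaint
print(torch.isfinite(fused).all())        # tensor(False): NaN/Inf now poison the forward pass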