15 | 15 | import sys
16 | 16 | import unittest
17 | 17 |
   | 18 | +import numpy as np
   | 19 | +import pytest
18 | 20 | import torch
19 | 21 | from transformers import AutoTokenizer, GemmaForCausalLM
20 | 22 |
21 | 23 | from diffusers import (
22 | 24 |     AutoencoderKL,
23 | 25 |     FlowMatchEulerDiscreteScheduler,
24 | 26 |     Lumina2Text2ImgPipeline,
25 | 27 |     Lumina2Transformer2DModel,
26 | 28 | )
27 |    | -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
   | 29 | +from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
28 | 30 |
29 | 31 |
30 | 32 | sys.path.append(".")
31 | 33 |
32 |    | -from utils import PeftLoraLoaderMixinTests  # noqa: E402
   | 34 | +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
33 | 35 |
34 | 36 |
35 | 37 | @require_peft_backend

@@ -130,3 +132,41 @@ def test_simple_inference_with_text_lora_fused(self):
130 | 132 |     @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
131 | 133 |     def test_simple_inference_with_text_lora_save_load(self):
132 | 134 |         pass
    | 135 | +
    | 136 | +    @skip_mps
    | 137 | +    @pytest.mark.xfail(
    | 138 | +        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
    | 139 | +        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
    | 140 | +        strict=False,
    | 141 | +    )
    | 142 | +    def test_lora_fuse_nan(self):
    | 143 | +        for scheduler_cls in self.scheduler_classes:
    | 144 | +            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
    | 145 | +            pipe = self.pipeline_class(**components)
    | 146 | +            pipe = pipe.to(torch_device)
    | 147 | +            pipe.set_progress_bar_config(disable=None)
    | 148 | +            _, _, inputs = self.get_dummy_inputs(with_generator=False)
    | 149 | +
    | 150 | +            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
    | 151 | +                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    | 152 | +                self.assertTrue(
    | 153 | +                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
    | 154 | +                )
    | 155 | +
    | 156 | +            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    | 157 | +            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    | 158 | +            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
    | 159 | +
    | 160 | +            # corrupt one LoRA weight with `inf` values
    | 161 | +            with torch.no_grad():
    | 162 | +                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
    | 163 | +
    | 164 | +            # with `safe_fusing=True` we should see an error
    | 165 | +            with self.assertRaises(ValueError):
    | 166 | +                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
    | 167 | +
    | 168 | +            # without it we should not see an error, but every image will be black
    | 169 | +            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
    | 170 | +            out = pipe(**inputs)[0]
    | 171 | +
    | 172 | +            self.assertTrue(np.isnan(out).all())
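
For context, `safe_fusing=True` is the switch this test exercises: `fuse_lora` validates the fused weights before committing them to the base layer and raises rather than letting non-finite values through. A minimal sketch of that guard, with illustrative names only (this is not the actual diffusers/peft implementation):

    import torch

    def fuse_lora_sketch(base_weight, lora_A, lora_B, scale=1.0, safe_fusing=True):
        # Standard LoRA fusion: W' = W + scale * (B @ A)
        fused = base_weight + scale * (lora_B @ lora_A)
        if safe_fusing and not torch.isfinite(fused).all():
            # A corrupted adapter (like the `inf` injected in the test above) is
            # caught here instead of silently producing all-NaN images downstream.
            raise ValueError("LoRA fusion produced NaN/inf values; aborting fuse.")
        return fused

With `safe_fusing=False` the corrupted weights are fused anyway, which is why the test's final assertion expects the pipeline output to be entirely NaN.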