|
49 | 49 | from diffusers.utils.torch_utils import randn_tensor
|
50 | 50 |
|
51 | 51 | from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
52 |
| - |
| 52 | +from peft import LoraConfig |
53 | 53 |
|
54 | 54 | enable_full_determinism()
|
55 | 55 |
|
@@ -299,7 +299,38 @@ def test_output_pretrained(self):
|
299 | 299 |
|
300 | 300 | self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
301 | 301 |
|
| 302 | + def test_lora_adapter(self): |
| 303 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 304 | + vae = self.model_class(**init_dict) |
| 305 | + |
| 306 | + target_modules_vae = [ |
| 307 | + "conv1", |
| 308 | + "conv2", |
| 309 | + "conv_in", |
| 310 | + "conv_shortcut", |
| 311 | + "conv", |
| 312 | + "conv_out", |
| 313 | + "skip_conv_1", |
| 314 | + "skip_conv_2", |
| 315 | + "skip_conv_3", |
| 316 | + "skip_conv_4", |
| 317 | + "to_k", |
| 318 | + "to_q", |
| 319 | + "to_v", |
| 320 | + "to_out.0", |
| 321 | + ] |
| 322 | + vae_lora_config = LoraConfig( |
| 323 | + r=16, |
| 324 | + init_lora_weights="gaussian", |
| 325 | + target_modules=target_modules_vae, |
| 326 | + ) |
| 327 | + |
| 328 | + vae.add_adapter(vae_lora_config, adapter_name="vae_lora") |
| 329 | + active_lora = vae.active_adapters() |
| 330 | + self.assertTrue(len(active_lora) == 1) |
| 331 | + self.assertTrue(active_lora[0] == "vae_lora") |
302 | 332 |
|
| 333 | + |
303 | 334 | class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
304 | 335 | model_class = AsymmetricAutoencoderKL
|
305 | 336 | main_input_name = "sample"
|
@@ -845,7 +876,7 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
|
845 | 876 | tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
846 | 877 | assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
847 | 878 |
|
848 |
| - |
| 879 | + |
849 | 880 | @slow
|
850 | 881 | class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
851 | 882 | def get_file_format(self, seed, shape):
|
|
0 commit comments