22
22
import numpy as np
23
23
import torch
24
24
import torch .nn as nn
25
- import torch .nn .functional as F
26
25
from huggingface_hub import hf_hub_download
27
26
from huggingface_hub .repocard import RepoCard
28
27
from packaging import version
41
40
StableDiffusionXLPipeline ,
42
41
UNet2DConditionModel ,
43
42
)
44
- from diffusers .loaders import AttnProcsLayers
45
- from diffusers .models .attention_processor import LoRAAttnProcessor , LoRAAttnProcessor2_0
46
43
from diffusers .utils .import_utils import is_accelerate_available , is_peft_available
47
44
from diffusers .utils .testing_utils import (
48
45
floats_tensor ,
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
78
75
return models_are_equal
79
76
80
77
81
- def create_unet_lora_layers (unet : nn .Module ):
82
- lora_attn_procs = {}
83
- for name in unet .attn_processors .keys ():
84
- cross_attention_dim = None if name .endswith ("attn1.processor" ) else unet .config .cross_attention_dim
85
- if name .startswith ("mid_block" ):
86
- hidden_size = unet .config .block_out_channels [- 1 ]
87
- elif name .startswith ("up_blocks" ):
88
- block_id = int (name [len ("up_blocks." )])
89
- hidden_size = list (reversed (unet .config .block_out_channels ))[block_id ]
90
- elif name .startswith ("down_blocks" ):
91
- block_id = int (name [len ("down_blocks." )])
92
- hidden_size = unet .config .block_out_channels [block_id ]
93
- lora_attn_processor_class = (
94
- LoRAAttnProcessor2_0 if hasattr (F , "scaled_dot_product_attention" ) else LoRAAttnProcessor
95
- )
96
- lora_attn_procs [name ] = lora_attn_processor_class (
97
- hidden_size = hidden_size , cross_attention_dim = cross_attention_dim
98
- )
99
- unet_lora_layers = AttnProcsLayers (lora_attn_procs )
100
- return lora_attn_procs , unet_lora_layers
101
-
102
-
103
78
@require_peft_backend
104
79
class PeftLoraLoaderMixinTests :
105
80
torch_device = "cuda" if torch .cuda .is_available () else "cpu"
@@ -140,8 +115,6 @@ def get_dummy_components(self, scheduler_cls=None):
140
115
r = rank , lora_alpha = rank , target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ], init_lora_weights = False
141
116
)
142
117
143
- unet_lora_attn_procs , unet_lora_layers = create_unet_lora_layers (unet )
144
-
145
118
if self .has_two_text_encoders :
146
119
pipeline_components = {
147
120
"unet" : unet ,
@@ -165,11 +138,8 @@ def get_dummy_components(self, scheduler_cls=None):
165
138
"feature_extractor" : None ,
166
139
"image_encoder" : None ,
167
140
}
168
- lora_components = {
169
- "unet_lora_layers" : unet_lora_layers ,
170
- "unet_lora_attn_procs" : unet_lora_attn_procs ,
171
- }
172
- return pipeline_components , lora_components , text_lora_config , unet_lora_config
141
+
142
+ return pipeline_components , text_lora_config , unet_lora_config
173
143
174
144
def get_dummy_inputs (self , with_generator = True ):
175
145
batch_size = 1
@@ -216,7 +186,7 @@ def test_simple_inference(self):
216
186
Tests a simple inference and makes sure it works as expected
217
187
"""
218
188
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
219
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
189
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
220
190
pipe = self .pipeline_class (** components )
221
191
pipe = pipe .to (self .torch_device )
222
192
pipe .set_progress_bar_config (disable = None )
@@ -231,7 +201,7 @@ def test_simple_inference_with_text_lora(self):
231
201
and makes sure it works as expected
232
202
"""
233
203
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
234
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
204
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
235
205
pipe = self .pipeline_class (** components )
236
206
pipe = pipe .to (self .torch_device )
237
207
pipe .set_progress_bar_config (disable = None )
@@ -262,7 +232,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
262
232
and makes sure it works as expected
263
233
"""
264
234
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
265
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
235
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
266
236
pipe = self .pipeline_class (** components )
267
237
pipe = pipe .to (self .torch_device )
268
238
pipe .set_progress_bar_config (disable = None )
@@ -309,7 +279,7 @@ def test_simple_inference_with_text_lora_fused(self):
309
279
and makes sure it works as expected
310
280
"""
311
281
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
312
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
282
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
313
283
pipe = self .pipeline_class (** components )
314
284
pipe = pipe .to (self .torch_device )
315
285
pipe .set_progress_bar_config (disable = None )
@@ -351,7 +321,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
351
321
and makes sure it works as expected
352
322
"""
353
323
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
354
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
324
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
355
325
pipe = self .pipeline_class (** components )
356
326
pipe = pipe .to (self .torch_device )
357
327
pipe .set_progress_bar_config (disable = None )
@@ -394,7 +364,7 @@ def test_simple_inference_with_text_lora_save_load(self):
394
364
Tests a simple usecase where users could use saving utilities for LoRA.
395
365
"""
396
366
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
397
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
367
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
398
368
pipe = self .pipeline_class (** components )
399
369
pipe = pipe .to (self .torch_device )
400
370
pipe .set_progress_bar_config (disable = None )
@@ -459,7 +429,7 @@ def test_simple_inference_save_pretrained(self):
459
429
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
460
430
"""
461
431
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
462
- components , _ , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
432
+ components , text_lora_config , _ = self .get_dummy_components (scheduler_cls )
463
433
pipe = self .pipeline_class (** components )
464
434
pipe = pipe .to (self .torch_device )
465
435
pipe .set_progress_bar_config (disable = None )
@@ -510,7 +480,7 @@ def test_simple_inference_with_text_unet_lora_save_load(self):
510
480
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
511
481
"""
512
482
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
513
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
483
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
514
484
pipe = self .pipeline_class (** components )
515
485
pipe = pipe .to (self .torch_device )
516
486
pipe .set_progress_bar_config (disable = None )
@@ -583,7 +553,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
583
553
and makes sure it works as expected
584
554
"""
585
555
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
586
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
556
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
587
557
pipe = self .pipeline_class (** components )
588
558
pipe = pipe .to (self .torch_device )
589
559
pipe .set_progress_bar_config (disable = None )
@@ -637,7 +607,7 @@ def test_simple_inference_with_text_lora_unet_fused(self):
637
607
and makes sure it works as expected - with unet
638
608
"""
639
609
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
640
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
610
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
641
611
pipe = self .pipeline_class (** components )
642
612
pipe = pipe .to (self .torch_device )
643
613
pipe .set_progress_bar_config (disable = None )
@@ -683,7 +653,7 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
683
653
and makes sure it works as expected
684
654
"""
685
655
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
686
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
656
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
687
657
pipe = self .pipeline_class (** components )
688
658
pipe = pipe .to (self .torch_device )
689
659
pipe .set_progress_bar_config (disable = None )
@@ -730,7 +700,7 @@ def test_simple_inference_with_text_unet_lora_unfused(self):
730
700
and makes sure it works as expected
731
701
"""
732
702
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
733
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
703
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
734
704
pipe = self .pipeline_class (** components )
735
705
pipe = pipe .to (self .torch_device )
736
706
pipe .set_progress_bar_config (disable = None )
@@ -780,7 +750,7 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
780
750
multiple adapters and set them
781
751
"""
782
752
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
783
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
753
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
784
754
pipe = self .pipeline_class (** components )
785
755
pipe = pipe .to (self .torch_device )
786
756
pipe .set_progress_bar_config (disable = None )
@@ -848,7 +818,7 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
848
818
multiple adapters and set/delete them
849
819
"""
850
820
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
851
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
821
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
852
822
pipe = self .pipeline_class (** components )
853
823
pipe = pipe .to (self .torch_device )
854
824
pipe .set_progress_bar_config (disable = None )
@@ -938,7 +908,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
938
908
multiple adapters and set them
939
909
"""
940
910
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
941
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
911
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
942
912
pipe = self .pipeline_class (** components )
943
913
pipe = pipe .to (self .torch_device )
944
914
pipe .set_progress_bar_config (disable = None )
@@ -1010,7 +980,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
1010
980
1011
981
def test_lora_fuse_nan (self ):
1012
982
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
1013
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
983
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1014
984
pipe = self .pipeline_class (** components )
1015
985
pipe = pipe .to (self .torch_device )
1016
986
pipe .set_progress_bar_config (disable = None )
@@ -1048,7 +1018,7 @@ def test_get_adapters(self):
1048
1018
are the expected results
1049
1019
"""
1050
1020
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
1051
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1021
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1052
1022
pipe = self .pipeline_class (** components )
1053
1023
pipe = pipe .to (self .torch_device )
1054
1024
pipe .set_progress_bar_config (disable = None )
@@ -1075,7 +1045,7 @@ def test_get_list_adapters(self):
1075
1045
are the expected results
1076
1046
"""
1077
1047
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
1078
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1048
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1079
1049
pipe = self .pipeline_class (** components )
1080
1050
pipe = pipe .to (self .torch_device )
1081
1051
pipe .set_progress_bar_config (disable = None )
@@ -1113,7 +1083,7 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
1113
1083
and makes sure it works as expected - with unet and multi-adapter case
1114
1084
"""
1115
1085
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
1116
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1086
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1117
1087
pipe = self .pipeline_class (** components )
1118
1088
pipe = pipe .to (self .torch_device )
1119
1089
pipe .set_progress_bar_config (disable = None )
@@ -1175,7 +1145,7 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
1175
1145
and makes sure it works as expected
1176
1146
"""
1177
1147
for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
1178
- components , _ , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1148
+ components , text_lora_config , unet_lora_config = self .get_dummy_components (scheduler_cls )
1179
1149
pipe = self .pipeline_class (** components )
1180
1150
pipe = pipe .to (self .torch_device )
1181
1151
pipe .set_progress_bar_config (disable = None )
0 commit comments