Skip to content

Commit 6a376ce

Browse files
authored
[LoRA] remove unnecessary components from lora peft test suite (#6401)
remove unnecessary components from lora peft suite/
1 parent 9f283b0 commit 6a376ce

File tree

1 file changed

+22
-52
lines changed

1 file changed

+22
-52
lines changed

tests/lora/test_lora_layers_peft.py

Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import torch
2424
import torch.nn as nn
25-
import torch.nn.functional as F
2625
from huggingface_hub import hf_hub_download
2726
from huggingface_hub.repocard import RepoCard
2827
from packaging import version
@@ -41,8 +40,6 @@
4140
StableDiffusionXLPipeline,
4241
UNet2DConditionModel,
4342
)
44-
from diffusers.loaders import AttnProcsLayers
45-
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
4643
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
4744
from diffusers.utils.testing_utils import (
4845
floats_tensor,
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
7875
return models_are_equal
7976

8077

81-
def create_unet_lora_layers(unet: nn.Module):
82-
lora_attn_procs = {}
83-
for name in unet.attn_processors.keys():
84-
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
85-
if name.startswith("mid_block"):
86-
hidden_size = unet.config.block_out_channels[-1]
87-
elif name.startswith("up_blocks"):
88-
block_id = int(name[len("up_blocks.")])
89-
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
90-
elif name.startswith("down_blocks"):
91-
block_id = int(name[len("down_blocks.")])
92-
hidden_size = unet.config.block_out_channels[block_id]
93-
lora_attn_processor_class = (
94-
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
95-
)
96-
lora_attn_procs[name] = lora_attn_processor_class(
97-
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
98-
)
99-
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
100-
return lora_attn_procs, unet_lora_layers
101-
102-
10378
@require_peft_backend
10479
class PeftLoraLoaderMixinTests:
10580
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -140,8 +115,6 @@ def get_dummy_components(self, scheduler_cls=None):
140115
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
141116
)
142117

143-
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
144-
145118
if self.has_two_text_encoders:
146119
pipeline_components = {
147120
"unet": unet,
@@ -165,11 +138,8 @@ def get_dummy_components(self, scheduler_cls=None):
165138
"feature_extractor": None,
166139
"image_encoder": None,
167140
}
168-
lora_components = {
169-
"unet_lora_layers": unet_lora_layers,
170-
"unet_lora_attn_procs": unet_lora_attn_procs,
171-
}
172-
return pipeline_components, lora_components, text_lora_config, unet_lora_config
141+
142+
return pipeline_components, text_lora_config, unet_lora_config
173143

174144
def get_dummy_inputs(self, with_generator=True):
175145
batch_size = 1
@@ -216,7 +186,7 @@ def test_simple_inference(self):
216186
Tests a simple inference and makes sure it works as expected
217187
"""
218188
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
219-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
189+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
220190
pipe = self.pipeline_class(**components)
221191
pipe = pipe.to(self.torch_device)
222192
pipe.set_progress_bar_config(disable=None)
@@ -231,7 +201,7 @@ def test_simple_inference_with_text_lora(self):
231201
and makes sure it works as expected
232202
"""
233203
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
234-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
204+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
235205
pipe = self.pipeline_class(**components)
236206
pipe = pipe.to(self.torch_device)
237207
pipe.set_progress_bar_config(disable=None)
@@ -262,7 +232,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
262232
and makes sure it works as expected
263233
"""
264234
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
265-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
235+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
266236
pipe = self.pipeline_class(**components)
267237
pipe = pipe.to(self.torch_device)
268238
pipe.set_progress_bar_config(disable=None)
@@ -309,7 +279,7 @@ def test_simple_inference_with_text_lora_fused(self):
309279
and makes sure it works as expected
310280
"""
311281
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
312-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
282+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
313283
pipe = self.pipeline_class(**components)
314284
pipe = pipe.to(self.torch_device)
315285
pipe.set_progress_bar_config(disable=None)
@@ -351,7 +321,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
351321
and makes sure it works as expected
352322
"""
353323
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
354-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
324+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
355325
pipe = self.pipeline_class(**components)
356326
pipe = pipe.to(self.torch_device)
357327
pipe.set_progress_bar_config(disable=None)
@@ -394,7 +364,7 @@ def test_simple_inference_with_text_lora_save_load(self):
394364
Tests a simple usecase where users could use saving utilities for LoRA.
395365
"""
396366
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
397-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
367+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
398368
pipe = self.pipeline_class(**components)
399369
pipe = pipe.to(self.torch_device)
400370
pipe.set_progress_bar_config(disable=None)
@@ -459,7 +429,7 @@ def test_simple_inference_save_pretrained(self):
459429
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
460430
"""
461431
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
462-
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
432+
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
463433
pipe = self.pipeline_class(**components)
464434
pipe = pipe.to(self.torch_device)
465435
pipe.set_progress_bar_config(disable=None)
@@ -510,7 +480,7 @@ def test_simple_inference_with_text_unet_lora_save_load(self):
510480
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
511481
"""
512482
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
513-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
483+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
514484
pipe = self.pipeline_class(**components)
515485
pipe = pipe.to(self.torch_device)
516486
pipe.set_progress_bar_config(disable=None)
@@ -583,7 +553,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
583553
and makes sure it works as expected
584554
"""
585555
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
586-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
556+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
587557
pipe = self.pipeline_class(**components)
588558
pipe = pipe.to(self.torch_device)
589559
pipe.set_progress_bar_config(disable=None)
@@ -637,7 +607,7 @@ def test_simple_inference_with_text_lora_unet_fused(self):
637607
and makes sure it works as expected - with unet
638608
"""
639609
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
640-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
610+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
641611
pipe = self.pipeline_class(**components)
642612
pipe = pipe.to(self.torch_device)
643613
pipe.set_progress_bar_config(disable=None)
@@ -683,7 +653,7 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
683653
and makes sure it works as expected
684654
"""
685655
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
686-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
656+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
687657
pipe = self.pipeline_class(**components)
688658
pipe = pipe.to(self.torch_device)
689659
pipe.set_progress_bar_config(disable=None)
@@ -730,7 +700,7 @@ def test_simple_inference_with_text_unet_lora_unfused(self):
730700
and makes sure it works as expected
731701
"""
732702
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
733-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
703+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
734704
pipe = self.pipeline_class(**components)
735705
pipe = pipe.to(self.torch_device)
736706
pipe.set_progress_bar_config(disable=None)
@@ -780,7 +750,7 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
780750
multiple adapters and set them
781751
"""
782752
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
783-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
753+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
784754
pipe = self.pipeline_class(**components)
785755
pipe = pipe.to(self.torch_device)
786756
pipe.set_progress_bar_config(disable=None)
@@ -848,7 +818,7 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
848818
multiple adapters and set/delete them
849819
"""
850820
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
851-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
821+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
852822
pipe = self.pipeline_class(**components)
853823
pipe = pipe.to(self.torch_device)
854824
pipe.set_progress_bar_config(disable=None)
@@ -938,7 +908,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
938908
multiple adapters and set them
939909
"""
940910
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
941-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
911+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
942912
pipe = self.pipeline_class(**components)
943913
pipe = pipe.to(self.torch_device)
944914
pipe.set_progress_bar_config(disable=None)
@@ -1010,7 +980,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
1010980

1011981
def test_lora_fuse_nan(self):
1012982
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1013-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
983+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
1014984
pipe = self.pipeline_class(**components)
1015985
pipe = pipe.to(self.torch_device)
1016986
pipe.set_progress_bar_config(disable=None)
@@ -1048,7 +1018,7 @@ def test_get_adapters(self):
10481018
are the expected results
10491019
"""
10501020
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1051-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
1021+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
10521022
pipe = self.pipeline_class(**components)
10531023
pipe = pipe.to(self.torch_device)
10541024
pipe.set_progress_bar_config(disable=None)
@@ -1075,7 +1045,7 @@ def test_get_list_adapters(self):
10751045
are the expected results
10761046
"""
10771047
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1078-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
1048+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
10791049
pipe = self.pipeline_class(**components)
10801050
pipe = pipe.to(self.torch_device)
10811051
pipe.set_progress_bar_config(disable=None)
@@ -1113,7 +1083,7 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
11131083
and makes sure it works as expected - with unet and multi-adapter case
11141084
"""
11151085
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1116-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
1086+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
11171087
pipe = self.pipeline_class(**components)
11181088
pipe = pipe.to(self.torch_device)
11191089
pipe.set_progress_bar_config(disable=None)
@@ -1175,7 +1145,7 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
11751145
and makes sure it works as expected
11761146
"""
11771147
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1178-
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
1148+
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
11791149
pipe = self.pipeline_class(**components)
11801150
pipe = pipe.to(self.torch_device)
11811151
pipe.set_progress_bar_config(disable=None)

0 commit comments

Comments
 (0)