
Commit c745230

noskill authored and sayakpaul committed
handle lora scale and clip skip in lpw sd and sdxl community pipelines (#8988)
* handle lora scale and clip skip in lpw sd and sdxl
* use StableDiffusionLoraLoaderMixin
* use StableDiffusionXLLoraLoaderMixin
* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 3566f4b commit c745230

File tree

examples/community/lpw_stable_diffusion.py
examples/community/lpw_stable_diffusion_xl.py

2 files changed: +100 -12 lines changed

examples/community/lpw_stable_diffusion.py

Lines changed: 58 additions & 6 deletions
@@ -13,13 +13,17 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor


@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+                text_embedding = prompt_embeds[0]
+            else:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

             if no_boseos_middle:
                 if i == 0:
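
As a standalone illustration of the `clip_skip` indexing above: `hidden_states` holds the embeddings output plus one entry per encoder layer, so `[-(clip_skip + 1)]` walks back from the last layer, and the final LayerNorm is applied manually because the skipped path bypasses it. A minimal sketch, assuming a CLIP text encoder from transformers; the model id is only an example:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a photo of a cat", return_tensors="pt").input_ids
clip_skip = 2  # take the hidden states two layers before the last

with torch.no_grad():
    out = text_encoder(tokens, output_hidden_states=True)
    # hidden_states[-1] is the last encoder layer; -(clip_skip + 1) walks back from it
    hidden = out.hidden_states[-(clip_skip + 1)]
    # apply the final LayerNorm, as the diff does, so the skipped-layer
    # representations keep the statistics downstream layers expect
    embeds = text_encoder.text_model.final_layer_norm(hidden)
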
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            clip_skip = 0
+        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
     return text_embeddings

@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
         current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
         uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
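
The `scale_lora_layers`/`unscale_lora_layers` pair added above brackets the text-encoder forward passes: the LoRA layers are scaled by the requested factor before encoding and restored afterwards so the scale does not leak into later calls. The same pattern in isolation, as a minimal sketch (the helper name and arguments are hypothetical, mirroring the diff):

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_with_lora_scale(pipe, text_input, lora_scale=None):
    # temporarily scale the LoRA layers of the text encoder (PEFT backend path)
    if lora_scale is not None and USE_PEFT_BACKEND:
        scale_lora_layers(pipe.text_encoder, lora_scale)
    try:
        embeddings = pipe.text_encoder(text_input)[0]
    finally:
        if lora_scale is not None and USE_PEFT_BACKEND:
            # restore the original weights so later calls are unaffected
            unscale_lora_layers(pipe.text_encoder, lora_scale)
    return embeddings
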
@@ -549,6 +585,8 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@ def _encode_prompt(
                 prompt=prompt,
                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                 max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=clip_skip,
+                lora_scale=lora_scale,
             )
             if prompt_embeds is None:
                 prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -865,6 +906,9 @@ def __call__(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -903,6 +947,7 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype

@@ -1044,6 +1091,7 @@ def text2img(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip=None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1101,6 +1149,9 @@ def text2img(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1135,6 +1186,7 @@ def text2img(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
         )

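Taken together, the changes to this file expose both knobs through the public call surface: `clip_skip` as a keyword argument and the LoRA scale via `cross_attention_kwargs`. A hedged usage sketch (the model id and LoRA path are placeholders, not from the commit):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_lora_weights("path/to/lora.safetensors")  # placeholder path

image = pipe(
    prompt="a (masterpiece:1.2) photo of a corgi",
    clip_skip=2,  # handled by the new code path above
    cross_attention_kwargs={"scale": 0.8},  # forwarded to the text encoder as lora_scale
).images[0]
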
examples/community/lpw_stable_diffusion_xl.py

Lines changed: 42 additions & 6 deletions
@@ -25,21 +25,25 @@
 from diffusers.loaders import (
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor


@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
+    lora_scale: Optional[int] = None,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
     """
     device = device or pipe._execution_device

+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if pipe.text_encoder is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder, lora_scale)
+
+        if pipe.text_encoder_2 is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     if prompt_2:
         prompt = f"{prompt} {prompt_2}"

@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
         bs_embed * num_images_per_prompt, -1
     )

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
     StableDiffusionMixin,
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
     The pipeline also inherits the following loading methods:
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

     Args:
@@ -743,7 +776,7 @@ def encode_prompt(

         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
             self._lora_scale = lora_scale

         if prompt is not None and isinstance(prompt, str):
@@ -1612,7 +1645,9 @@ def __call__(
             image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 3. Encode input prompt
-        (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
+        lora_scale = (
+            self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
+        )

         negative_prompt = negative_prompt if negative_prompt is not None else ""

@@ -1627,6 +1662,7 @@ def __call__(
             neg_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype


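As with the SD pipeline, the SDXL variant now reads the LoRA scale out of `cross_attention_kwargs` and forwards it, together with `clip_skip`, into `get_weighted_text_embeddings_sdxl`, where both text encoders are scaled and then restored. A hedged usage sketch (the model id and LoRA path are placeholders, not from the commit):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id
    custom_pipeline="lpw_stable_diffusion_xl",
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_lora_weights("path/to/sdxl_lora.safetensors")  # placeholder path

image = pipe(
    prompt="a (detailed:1.3) watercolor landscape",
    clip_skip=1,  # use the pre-final CLIP layer
    cross_attention_kwargs={"scale": 0.7},  # read back as lora_scale in __call__
).images[0]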