@@ -13,13 +13,17 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
 
@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+                text_embedding = prompt_embeds[0]
+            else:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+                # Access the `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
 
         if no_boseos_middle:
             if i == 0:
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            clip_skip = 0
+        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
     return text_embeddings
 
 
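For reference, a minimal standalone sketch of the `clip_skip` indexing used in the two branches above, assuming a standard Hugging Face CLIP text encoder (the checkpoint name and prompt are illustrative):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative checkpoint; any CLIP text encoder with this layout works.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

text_input = tokenizer("a photo of an astronaut", return_tensors="pt").input_ids
clip_skip = 1  # 1 = use the pre-final encoder layer

with torch.no_grad():
    outputs = text_encoder(text_input, output_hidden_states=True)
    # `hidden_states[-1]` is the last encoder layer's output *before* the
    # final LayerNorm, so `-(clip_skip + 1)` walks back from there.
    hidden = outputs.hidden_states[-(clip_skip + 1)]
    # Re-apply the final LayerNorm so the skipped representation stays in
    # the distribution the downstream cross-attention expects.
    text_embedding = text_encoder.text_model.final_layer_norm(hidden)
```

With `clip_skip = 0` this reproduces `last_hidden_state` exactly, which is why the `else` branch above can simply default `clip_skip` to 0.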
@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
 
@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
         current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
         uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
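Taken together, the scaling block at the top of `get_weighted_text_embeddings` and this unscaling block bracket the text-encoder forward passes: scale, encode, unscale. A rough sketch of that lifecycle on the PEFT path, with `pipe` and `token_ids` as placeholder names (the try/finally is only here to make the pairing explicit; the function body above reaches the unscale call unconditionally):

```python
from diffusers.utils import scale_lora_layers, unscale_lora_layers

lora_scale = 0.7  # illustrative value; normally cross_attention_kwargs["scale"]

# Multiply the text encoder's LoRA deltas by 0.7 for this encode...
scale_lora_layers(pipe.text_encoder, lora_scale)
try:
    embeds = pipe.text_encoder(token_ids)[0]  # forward pass sees the scaled LoRA
finally:
    # ...then divide them back out so the next call starts from scale 1.0.
    unscale_lora_layers(pipe.text_encoder, lora_scale)
```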
@@ -549,6 +585,8 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@ def _encode_prompt(
                 prompt=prompt,
                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                 max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=clip_skip,
+                lora_scale=lora_scale,
             )
             if prompt_embeds is None:
                 prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -865,6 +906,9 @@ def __call__(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -903,6 +947,7 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
 
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
 
@@ -1044,6 +1091,7 @@ def text2img(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip=None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1101,6 +1149,9 @@ def text2img(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1135,6 +1186,7 @@ def text2img(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
         )
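From the caller's side, the two new knobs travel through `__call__` as a direct `clip_skip` argument and as the `"scale"` entry of `cross_attention_kwargs`. A hedged usage sketch, assuming a LoRA has already been trained for the base model; the checkpoint id and LoRA path are illustrative:

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative base checkpoint, loaded through this community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder path

image = pipe(
    prompt="a (masterpiece:1.2) photo of an astronaut riding a horse",
    clip_skip=2,                            # use the 2nd-from-last CLIP layer
    cross_attention_kwargs={"scale": 0.7},  # picked up as lora_scale above
).images[0]
```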