 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -120,7 +121,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
+class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
     r"""
     Pipeline for text-to-image generation using Kolors.

@@ -130,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
     The pipeline also inherits the following loading methods:
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

     Args:
         vae ([`AutoencoderKL`]):
@@ -148,7 +150,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
             `Kwai-Kolors/Kolors-diffusers`.
     """

-    model_cpu_offload_seq = "text_encoder->unet->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = [
+        "image_encoder",
+        "feature_extractor",
+    ]
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
@@ -166,11 +172,21 @@ def __init__(
         tokenizer: ChatGLMTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()

-        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -343,6 +359,77 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+
+                image_embeds.append(single_image_embeds[None, :])
+                if do_classifier_free_guidance:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
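
The two methods added above are copied from the Stable Diffusion pipeline, so they follow the same stacking convention: per adapter, unconditional embeddings are concatenated in front of the conditional ones along the batch dimension when classifier-free guidance is active, and both are repeated `num_images_per_prompt` times. A toy, self-contained illustration of that layout with dummy tensors (no image encoder involved):

```python
import torch

# Stand-ins for one adapter's embeddings: 1 prompt, 1 image token, 4-dim embedding.
positive = torch.ones(1, 1, 4)
negative = torch.zeros(1, 1, 4)

num_images_per_prompt = 2
do_classifier_free_guidance = True

single = torch.cat([positive] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
    neg = torch.cat([negative] * num_images_per_prompt, dim=0)
    single = torch.cat([neg, single], dim=0)  # unconditional rows first, then conditional

print(single.shape)  # torch.Size([4, 1, 4]): 2 unconditional rows followed by 2 conditional rows
```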
@@ -364,16 +451,25 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
+        num_inference_steps,
         height,
         width,
         negative_prompt=None,
         prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
+        if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
+
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -420,6 +516,21 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
         if max_sequence_length is not None and max_sequence_length > 256:
             raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
@@ -563,6 +674,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -649,6 +762,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
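
As the docstring notes, `ip_adapter_image_embeds` lets callers skip re-encoding the reference image on every call. A hedged sketch of that workflow, assuming `pipe` is the Kolors pipeline with an IP-Adapter already loaded (as in the earlier loading sketch) and an illustrative image URL:

```python
from diffusers.utils import load_image

reference = load_image("https://example.com/reference.png")  # illustrative URL

# Encode once; with guidance enabled each returned tensor already holds the negative
# embeddings stacked in front of the positive ones along the batch dimension.
embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=reference,
    ip_adapter_image_embeds=None,
    device=pipe.device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# Reuse the cached embeddings across prompts without touching the image encoder again.
for prompt in ["a watercolor landscape", "a neon city at night"]:
    image = pipe(prompt=prompt, ip_adapter_image_embeds=embeds).images[0]
```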
@@ -719,13 +838,16 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            num_inference_steps,
             height,
             width,
             negative_prompt,
             prompt_embeds,
             pooled_prompt_embeds,
             negative_prompt_embeds,
             negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -815,6 +937,15 @@ def __call__(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -856,6 +987,9 @@ def __call__(
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
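
Because `image_embeds` only enters `added_cond_kwargs` when an IP-Adapter input is supplied, plain text-to-image calls are unaffected. Continuing the loading sketch from above (the reference URL is illustrative), an image-prompted generation would look roughly like this:

```python
from diffusers.utils import load_image

reference = load_image("https://example.com/style_reference.png")  # illustrative URL

image = pipe(
    prompt="a portrait in the style of the reference image",
    ip_adapter_image=reference,
    num_inference_steps=25,
    guidance_scale=5.0,
).images[0]
image.save("kolors_ip_adapter.png")
```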