
Commit edddf3d

asomoza authored and sayakpaul committed
[Kolors] Add IP Adapter (#8901)
* initial draft
* apply suggestions
* fix failing test
* added ipa to img2img
* add docs
* apply suggestions
1 parent a9de5cf commit edddf3d

7 files changed: 362 additions & 16 deletions


docs/source/en/api/pipelines/kolors.md

Lines changed: 58 additions & 0 deletions
@@ -41,6 +41,64 @@ image = pipe(
 image.save("kolors_sample.png")
 ```

+### IP Adapter
+
+Kolors needs a different IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as an image encoder.
+
+<Tip>
+
+Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs.
+
+</Tip>
+
+<Tip>
+
+While Kolors is integrated in Diffusers, you need to load the image encoder from a revision to use the safetensor files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.
+
+</Tip>
+
+```python
+import torch
+from transformers import CLIPVisionModelWithProjection
+
+from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
+from diffusers.utils import load_image
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "Kwai-Kolors/Kolors-IP-Adapter-Plus",
+    subfolder="image_encoder",
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.float16,
+    revision="refs/pr/4",
+)
+
+pipe = KolorsPipeline.from_pretrained(
+    "Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
+
+pipe.load_ip_adapter(
+    "Kwai-Kolors/Kolors-IP-Adapter-Plus",
+    subfolder="",
+    weight_name="ip_adapter_plus_general.safetensors",
+    revision="refs/pr/4",
+    image_encoder_folder=None,
+)
+pipe.enable_model_cpu_offload()
+
+ipa_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/cat_square.png")
+
+image = pipe(
+    prompt="best quality, high quality",
+    negative_prompt="",
+    guidance_scale=6.5,
+    num_inference_steps=25,
+    ip_adapter_image=ipa_image,
+).images[0]
+
+image.save("kolors_ipa_sample.png")
+```
+
 ## KolorsPipeline

 [[autodoc]] KolorsPipeline
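
Since this commit makes `KolorsPipeline` inherit from `IPAdapterMixin` (see the pipeline diff below), the usual IP Adapter controls apply once the adapter is loaded. A minimal sketch, assuming `pipe` and `ipa_image` were set up exactly as in the documentation snippet above; `set_ip_adapter_scale` comes from `IPAdapterMixin`, and the 0.5 value is only an illustrative middle ground, not a recommended default:

```python
# Hedged sketch: tune how strongly the reference image steers generation.
# Assumes `pipe` and `ipa_image` from the docs example above are already in scope.
pipe.set_ip_adapter_scale(0.5)  # roughly: 0.0 disables the adapter, 1.0 is full strength

image = pipe(
    prompt="best quality, high quality",
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    ip_adapter_image=ipa_image,
).images[0]
```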

src/diffusers/loaders/ip_adapter.py

Lines changed: 9 additions & 2 deletions
@@ -222,7 +222,8 @@ def load_ip_adapter(

         # create feature extractor if it has not been registered to the pipeline yet
         if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
-            feature_extractor = CLIPImageProcessor()
+            clip_image_size = self.image_encoder.config.image_size
+            feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
             self.register_modules(feature_extractor=feature_extractor)

         # load ip-adapter into unet
@@ -319,7 +320,13 @@ def unload_ip_adapter(self):

         # remove hidden encoder
         self.unet.encoder_hid_proj = None
-        self.config.encoder_hid_dim_type = None
+        self.unet.config.encoder_hid_dim_type = None
+
+        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+            self.unet.text_encoder_hid_proj = None
+            self.unet.config.encoder_hid_dim_type = "text_proj"

         # restore original Unet attention processors layers
         attn_procs = {}
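
The first hunk matters for Kolors because its image encoder is the 336px CLIP variant, while `CLIPImageProcessor()` defaults to 224px preprocessing. A minimal sketch of what the new code resolves to in that case; the hard-coded 336 below is an assumption standing in for `image_encoder.config.image_size`:

```python
from transformers import CLIPImageProcessor

# Assumed value of image_encoder.config.image_size for openai/clip-vit-large-patch14-336.
clip_image_size = 336

# Before this commit the auto-created extractor used the 224px defaults;
# now it resizes and center-crops to the encoder's native resolution.
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
print(feature_extractor.size, feature_extractor.crop_size)
```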

src/diffusers/loaders/unet.py

Lines changed: 9 additions & 0 deletions
@@ -823,6 +823,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
+
+        # Kolors Unet already has a `encoder_hid_proj`
+        if (
+            self.encoder_hid_proj is not None
+            and self.config.encoder_hid_dim_type == "text_proj"
+            and not hasattr(self, "text_encoder_hid_proj")
+        ):
+            self.text_encoder_hid_proj = self.encoder_hid_proj
+
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
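
Together with the `unload_ip_adapter` hunk above, this implements a stash-and-restore hand-off: the Kolors UNet already uses `encoder_hid_proj` for its text projection, so that module is parked in `text_encoder_hid_proj` while the IP Adapter owns `encoder_hid_proj`, and it is moved back on unload. A rough sketch of that hand-off using `SimpleNamespace` stand-ins rather than a real `UNet2DConditionModel` (the string values are placeholders):

```python
from types import SimpleNamespace

# Stand-in for a freshly loaded Kolors UNet.
unet = SimpleNamespace(
    encoder_hid_proj="kolors_text_projection",
    config=SimpleNamespace(encoder_hid_dim_type="text_proj"),
)

# _load_ip_adapter_weights(): stash the existing text projection before it is replaced.
if (
    unet.encoder_hid_proj is not None
    and unet.config.encoder_hid_dim_type == "text_proj"
    and not hasattr(unet, "text_encoder_hid_proj")
):
    unet.text_encoder_hid_proj = unet.encoder_hid_proj

unet.encoder_hid_proj = "ip_adapter_image_projection_layers"
unet.config.encoder_hid_dim_type = "ip_image_proj"

# unload_ip_adapter(): put the original text projection back.
if hasattr(unet, "text_encoder_hid_proj") and unet.text_encoder_hid_proj is not None:
    unet.encoder_hid_proj = unet.text_encoder_hid_proj
    unet.text_encoder_hid_proj = None
    unet.config.encoder_hid_dim_type = "text_proj"

print(unet.encoder_hid_proj)  # kolors_text_projection
```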

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions
@@ -1027,6 +1027,10 @@ def process_encoder_hidden_states(
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                 )
+
+            if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
+                encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
+
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
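
At inference time this means that, for a Kolors UNet with an IP Adapter loaded, the text hidden states are first run through the preserved text projection and only then paired with the projected image embeds. A rough sketch of that flow with `nn.Linear` stand-ins and made-up dimensions (the real projection modules and shapes come from the Kolors checkpoint, not from this snippet):

```python
import torch
from torch import nn

# Stand-ins for unet.text_encoder_hid_proj and unet.encoder_hid_proj; dimensions are illustrative only.
text_encoder_hid_proj = nn.Linear(4096, 2048)
image_proj = nn.Linear(1280, 2048)

encoder_hidden_states = torch.randn(1, 77, 4096)  # hypothetical text hidden states
image_embeds = torch.randn(1, 1280)               # hypothetical CLIP image embedding

# With this patch: project the text states first, then hand the tuple to the attention processors.
encoder_hidden_states = text_encoder_hid_proj(encoder_hidden_states)
encoder_hidden_states = (encoder_hidden_states, image_proj(image_embeds))
```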

src/diffusers/pipelines/kolors/pipeline_kolors.py

Lines changed: 140 additions & 6 deletions
@@ -15,11 +15,12 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -120,7 +121,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
+class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
     r"""
     Pipeline for text-to-image generation using Kolors.

@@ -130,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
     The pipeline also inherits the following loading methods:
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

     Args:
         vae ([`AutoencoderKL`]):
@@ -148,7 +150,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
             `Kwai-Kolors/Kolors-diffusers`.
     """

-    model_cpu_offload_seq = "text_encoder->unet->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = [
+        "image_encoder",
+        "feature_extractor",
+    ]
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
@@ -166,11 +172,21 @@ def __init__(
         tokenizer: ChatGLMTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()

-        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -343,6 +359,77 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+
+                image_embeds.append(single_image_embeds[None, :])
+                if do_classifier_free_guidance:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -364,16 +451,25 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
+        num_inference_steps,
         height,
         width,
         negative_prompt=None,
         prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
+        if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
+
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

@@ -420,6 +516,21 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
         if max_sequence_length is not None and max_sequence_length > 256:
             raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")

@@ -563,6 +674,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -649,6 +762,12 @@
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -719,13 +838,16 @@
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            num_inference_steps,
             height,
             width,
             negative_prompt,
             prompt_embeds,
             pooled_prompt_embeds,
             negative_prompt_embeds,
             negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -815,6 +937,15 @@
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

@@ -856,6 +987,9 @@
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
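
One practical consequence of the new `prepare_ip_adapter_image_embeds` / `check_inputs` plumbing is that the image embeddings can be computed once and reused across calls via `ip_adapter_image_embeds`. A hedged sketch, assuming `pipe` and `ipa_image` from the documentation example above; the helper is called here roughly as `__call__` does internally, and the seed loop is only illustrative:

```python
import torch

# Precompute IP-Adapter image embeddings once and reuse them for several generations.
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ipa_image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,  # embeddings then include the negative half, as check_inputs expects
)

for seed in (0, 1, 2):
    image = pipe(
        prompt="best quality, high quality",
        negative_prompt="",
        guidance_scale=6.5,
        num_inference_steps=25,
        ip_adapter_image_embeds=image_embeds,
        generator=torch.Generator("cpu").manual_seed(seed),
    ).images[0]
    image.save(f"kolors_ipa_{seed}.png")
```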

0 commit comments
