diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 52f150569364..0f35d004a09a 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import functional as F
 
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        if text_model_output is None:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+            text_input_ids = text_inputs.input_ids
+            text_mask = text_inputs.attention_mask.bool().to(device)
+
+            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device))
+            text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
-        text_embeddings = text_encoder_output.text_embeds
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_embeddings = text_encoder_output.text_embeds
+            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+        else:
+            batch_size = text_model_output[0].shape[0]
+            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            text_mask = text_attention_mask
 
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         if do_classifier_free_guidance:
             uncond_tokens = [""] * batch_size
 
-            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
@@ -235,7 +248,7 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         prior_num_inference_steps: int = 25,
         decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ def __call__(
         prior_latents: Optional[torch.FloatTensor] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
         prior_guidance_scale: float = 4.0,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@ def __call__(
 
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+                The prompt or prompts to guide the image generation. This can only be left undefined if
+                `text_model_output` and `text_attention_mask` are passed.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,18 +303,29 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            text_model_output (`CLIPTextModelOutput`, *optional*):
+                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+                can be passed for tasks like text embedding interpolations. Make sure to also pass
+                `text_attention_mask` in this case. `prompt` can then be left as `None`.
+            text_attention_mask (`torch.Tensor`, *optional*):
+                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+                masks are necessary when passing `text_model_output`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
+        if prompt is not None:
+            if isinstance(prompt, str):
+                batch_size = 1
+            elif isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            batch_size = text_model_output[0].shape[0]
+
         device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
@@ -306,7 +333,7 @@ def __call__(
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
 
         # prior
@@ -315,6 +342,7 @@ def __call__(
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
+
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
@@ -378,6 +406,7 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
+
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
+
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 1d0320d38874..0b83407d8ccd 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -126,7 +126,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
@@ -139,15 +138,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         )
         text_input_ids = text_inputs.input_ids
         text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-
         text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
         text_embeddings = text_encoder_output.text_embeds
@@ -199,14 +189,15 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
 
         return text_embeddings, text_encoder_hidden_states, text_mask
 
-    def _encode_image(self, image, device, num_images_per_prompt):
+    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
         dtype = next(self.image_encoder.parameters()).dtype
 
-        if not isinstance(image, torch.Tensor):
-            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+        if image_embeddings is None:
+            if not isinstance(image, torch.Tensor):
+                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
 
-        image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image).image_embeds
+            image = image.to(device=device, dtype=dtype)
+            image_embeddings = self.image_encoder(image).image_embeds
 
         image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
 
@@ -258,13 +249,14 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
         num_images_per_prompt: int = 1,
         decoder_num_inference_steps: int = 25,
         super_res_num_inference_steps: int = 7,
         generator: Optional[torch.Generator] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        image_embeddings: Optional[torch.Tensor] = None,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -277,7 +269,7 @@ def __call__(
                 The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                 configuration of
                 [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
-                `CLIPFeatureExtractor`.
+                `CLIPFeatureExtractor`. Can be left as `None` only when `image_embeddings` are passed.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             decoder_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -299,18 +291,24 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            image_embeddings (`torch.Tensor`, *optional*):
+                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+                can be passed for tasks like image interpolations. `image` can then be left as `None`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(image, PIL.Image.Image):
-            batch_size = 1
-        elif isinstance(image, list):
-            batch_size = len(image)
+        if image is not None:
+            if isinstance(image, PIL.Image.Image):
+                batch_size = 1
+            elif isinstance(image, list):
+                batch_size = len(image)
+            else:
+                batch_size = image.shape[0]
         else:
-            batch_size = image.shape[0]
+            batch_size = image_embeddings.shape[0]
 
         prompt = [""] * batch_size
 
@@ -324,10 +322,9 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance
         )
 
-        image_embeddings = self._encode_image(image, device, num_images_per_prompt)
+        image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
 
         # decoder
-
         text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
             image_embeddings=image_embeddings,
             text_embeddings=text_embeddings,
@@ -343,14 +340,16 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
-        decoder_latents = self.prepare_latents(
-            (batch_size, num_channels_latents, height, width),
-            text_encoder_hidden_states.dtype,
-            device,
-            generator,
-            decoder_latents,
-            self.decoder_scheduler,
-        )
+
+        if decoder_latents is None:
+            decoder_latents = self.prepare_latents(
+                (batch_size, num_channels_latents, height, width),
+                text_encoder_hidden_states.dtype,
+                device,
+                generator,
+                decoder_latents,
+                self.decoder_scheduler,
+            )
 
         for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
@@ -395,14 +394,16 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
-        super_res_latents = self.prepare_latents(
-            (batch_size, channels, height, width),
-            image_small.dtype,
-            device,
-            generator,
-            super_res_latents,
-            self.super_res_scheduler,
-        )
+
+        if super_res_latents is None:
+            super_res_latents = self.prepare_latents(
+                (batch_size, channels, height, width),
+                image_small.dtype,
+                device,
+                generator,
+                super_res_latents,
+                self.super_res_scheduler,
+            )
 
         interpolate_antialias = {}
         if "antialias" in inspect.signature(F.interpolate).parameters:
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index fb0cb75ea703..670082c20c24 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -248,6 +248,120 @@ def test_unclip(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_unclip_passed_text_embed(self):
+        device = torch.device("cpu")
+
+        class DummyScheduler:
+            init_noise_sigma = 1
+
+        prior = self.dummy_prior
+        decoder = self.dummy_decoder
+        text_proj = self.dummy_text_proj
+        text_encoder = self.dummy_text_encoder
+        tokenizer = self.dummy_tokenizer
+        super_res_first = self.dummy_super_res_first
+        super_res_last = self.dummy_super_res_last
+
+        prior_scheduler = UnCLIPScheduler(
+            variance_type="fixed_small_log",
+            prediction_type="sample",
+            num_train_timesteps=1000,
+            clip_sample_range=5.0,
+        )
+
+        decoder_scheduler = UnCLIPScheduler(
+            variance_type="learned_range",
+            prediction_type="epsilon",
+            num_train_timesteps=1000,
+        )
+
+        super_res_scheduler = UnCLIPScheduler(
+            variance_type="fixed_small_log",
+            prediction_type="epsilon",
+            num_train_timesteps=1000,
+        )
+
+        pipe = UnCLIPPipeline(
+            prior=prior,
+            decoder=decoder,
+            text_proj=text_proj,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            super_res_first=super_res_first,
+            super_res_last=super_res_last,
+            prior_scheduler=prior_scheduler,
+            decoder_scheduler=decoder_scheduler,
+            super_res_scheduler=super_res_scheduler,
+        )
+        pipe = pipe.to(device)
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        dtype = prior.dtype
+        batch_size = 1
+
+        shape = (batch_size, prior.config.embedding_dim)
+        prior_latents = pipe.prepare_latents(
+            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
+        )
+        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
+        decoder_latents = pipe.prepare_latents(
+            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
+        )
+
+        shape = (
+            batch_size,
+            super_res_first.in_channels // 2,
+            super_res_first.sample_size,
+            super_res_first.sample_size,
+        )
+        super_res_latents = pipe.prepare_latents(
+            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
+        )
+
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "this is a prompt example"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = pipe(
+            [prompt],
+            generator=generator,
+            prior_num_inference_steps=2,
+            decoder_num_inference_steps=2,
+            super_res_num_inference_steps=2,
+            prior_latents=prior_latents,
+            decoder_latents=decoder_latents,
+            super_res_latents=super_res_latents,
+            output_type="np",
+        )
+        image = output.images
+
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_model_output = text_encoder(text_inputs.input_ids)
+        text_attention_mask = text_inputs.attention_mask
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_text = pipe(
+            generator=generator,
+            prior_num_inference_steps=2,
+            decoder_num_inference_steps=2,
+            super_res_num_inference_steps=2,
+            prior_latents=prior_latents,
+            decoder_latents=decoder_latents,
+            super_res_latents=super_res_latents,
+            text_model_output=text_model_output,
+            text_attention_mask=text_attention_mask,
+            output_type="np",
+        )[0]
+
+        # make sure passing text embeddings manually is identical
+        assert np.abs(image - image_from_text).max() < 1e-4
+
 
 @slow
 @require_torch_gpu
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index 5835a2169aa4..87ad14146a11 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -407,6 +407,55 @@ def test_unclip_image_variation_input_num_images_per_prompt(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_unclip_passed_image_embed(self):
+        device = torch.device("cpu")
+        seed = 0
+
+        class DummyScheduler:
+            init_noise_sigma = 1
+
+        pipe = self.get_pipeline(device)
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        dtype = pipe.decoder.dtype
+        batch_size = 1
+
+        shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
+        decoder_latents = pipe.prepare_latents(
+            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
+        )
+
+        shape = (
+            batch_size,
+            pipe.super_res_first.in_channels // 2,
+            pipe.super_res_first.sample_size,
+            pipe.super_res_first.sample_size,
+        )
+        super_res_latents = pipe.prepare_latents(
+            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
+        )
+
+        pipeline_inputs = self.get_pipeline_inputs(device, seed)
+
+        img_out_1 = pipe(
+            **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
+        ).images
+
+        pipeline_inputs = self.get_pipeline_inputs(device, seed)
+        # Don't pass image, instead pass embedding
+        image = pipeline_inputs.pop("image")
+        image_embeddings = pipe.image_encoder(image).image_embeds
+
+        img_out_2 = pipe(
+            **pipeline_inputs,
+            decoder_latents=decoder_latents,
+            super_res_latents=super_res_latents,
+            image_embeddings=image_embeddings,
+        ).images
+
+        # make sure passing image embeddings manually is identical
+        assert np.abs(img_out_1 - img_out_2).max() < 1e-4
+
 
 @slow
 @require_torch_gpu
@@ -426,11 +475,10 @@ def test_unclip_image_variation_karlo(self):
             "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
         )
 
-        pipeline = UnCLIPImageVariationPipeline.from_pretrained(
-            "fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
-        )
+        pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers")
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
+        pipeline.enable_sequential_cpu_offload()
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
         output = pipeline(
@@ -442,7 +490,5 @@ def test_unclip_image_variation_karlo(self):
 
         image = output.images[0]
 
-        np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image)
-
         assert image.shape == (256, 256, 3)
-        assert np.abs(expected_image - image).max() < 1e-2
+        assert np.abs(expected_image - image).max() < 5e-2
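The docstrings added in this patch name embedding interpolation as the intended use case for the new arguments. A minimal sketch of that idea, under assumptions, is below: `pipe` is an UnCLIPImageVariationPipeline and `emb_a` / `emb_b` are two precomputed image embeddings obtained as in the example earlier; plain linear interpolation is used here for brevity, and spherical interpolation is a common alternative.

# Illustrative sketch, not part of the patch: interpolate between two precomputed
# CLIP image embeddings and decode each intermediate embedding into an image.
import torch

for alpha in [0.0, 0.25, 0.5, 0.75, 1.0]:
    # Blend the two embeddings; torch.lerp(a, b, w) computes a + w * (b - a).
    mixed = torch.lerp(emb_a, emb_b, alpha)
    frame = pipe(image_embeddings=mixed).images[0]
    frame.save(f"interp_{alpha:.2f}.png")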