From c59d38a0ad708573fd4b8ddf471aa64ce72c0749 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Dec 2022 13:40:40 +0000 Subject: [PATCH 1/7] [Unclip] Make sure latents can be reused --- .../pipelines/unclip/pipeline_unclip.py | 54 ++++++++++--------- .../unclip/pipeline_unclip_image_variation.py | 37 +++++++------ .../unclip/test_unclip_image_variation.py | 4 +- 3 files changed, 51 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index a8e8cba8375d..3f16c55b3cda 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -315,14 +315,16 @@ def __call__( prior_timesteps_tensor = self.prior_scheduler.timesteps embedding_dim = self.prior.config.embedding_dim - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - text_embeddings.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) + + if prior_latents is None: + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance @@ -378,14 +380,16 @@ def __call__( num_channels_latents = self.decoder.in_channels height = self.decoder.sample_size width = self.decoder.sample_size - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance @@ -430,14 +434,16 @@ def __call__( channels = self.super_res_first.in_channels // 2 height = self.super_res_first.sample_size width = self.super_res_first.sample_size - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index 32b950397a32..33a77fbf5d69 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -327,7 +327,6 @@ def __call__( image_embeddings = self._encode_image(image, device, num_images_per_prompt) # decoder - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, text_embeddings=text_embeddings, @@ -343,14 +342,16 @@ def __call__( num_channels_latents = self.decoder.in_channels height = self.decoder.sample_size width = self.decoder.sample_size - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance @@ -395,14 +396,16 @@ def __call__( channels = self.super_res_first.in_channels // 2 height = self.super_res_first.sample_size width = self.super_res_first.sample_size - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index 64b51536084b..9b790e2059e4 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -437,7 +437,5 @@ def test_unclip_image_variation_karlo(self): image = output.images[0] - np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image) - assert image.shape == (256, 256, 3) - assert np.abs(expected_image - image).max() < 1e-2 + assert np.abs(expected_image - image).max() < 5e-2 From a20ea60e4a5c699c7ad6ff0971d9753946faac89 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Dec 2022 20:58:39 +0100 Subject: [PATCH 2/7] allow one to directly pass embeddings --- 1 | 502 ++++++++++++++++++ .../pipelines/unclip/pipeline_unclip.py | 75 +-- .../unclip/pipeline_unclip_image_variation.py | 14 +- 3 files changed, 554 insertions(+), 37 deletions(-) create mode 100644 1 diff --git a/1 b/1 new file mode 100644 index 000000000000..31edf7091714 --- /dev/null +++ b/1 @@ -0,0 +1,502 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union, Tuple + +import torch +from torch.nn import functional as F + +from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import UnCLIPScheduler +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...utils import is_accelerate_available, logging +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process. Just a modified DDPMScheduler. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + if text_model_output is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + else: + text_embeddings, text_encoder_hidden_states = text_encoder_output[0], text_encoder_output[1] + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + if prior_latents is None: + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=text_embeddings, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 3f16c55b3cda..1da5b227bed1 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import torch from torch.nn import functional as F @@ -22,6 +22,7 @@ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput from diffusers.schedulers import UnCLIPScheduler from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...utils import is_accelerate_available, logging from .text_proj import UnCLIPTextProjModel @@ -117,35 +118,41 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - batch_size = len(prompt) if isinstance(prompt, list) else 1 + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None): - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) - text_embeddings = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + else: + batch_size = text_model_output[0].shape[0] + text_embeddings, text_encoder_hidden_states = text_encoder_output[0], text_encoder_output[1] + text_mask = text_attention_mask + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size @@ -244,6 +251,8 @@ def __call__( prior_latents: Optional[torch.FloatTensor] = None, decoder_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, prior_guidance_scale: float = 4.0, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", @@ -293,12 +302,16 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = text_model_output[0].shape[0] + device = self._execution_device batch_size = batch_size * num_images_per_prompt @@ -306,7 +319,7 @@ def __call__( do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) # prior diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index 33a77fbf5d69..d656ea815262 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -199,14 +199,15 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return text_embeddings, text_encoder_hidden_states, text_mask - def _encode_image(self, image, device, num_images_per_prompt): + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): dtype = next(self.image_encoder.parameters()).dtype - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - image_embeddings = self.image_encoder(image).image_embeds + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) @@ -265,6 +266,7 @@ def __call__( generator: Optional[torch.Generator] = None, decoder_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.Tensor] = None, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -324,7 +326,7 @@ def __call__( prompt, device, num_images_per_prompt, do_classifier_free_guidance ) - image_embeddings = self._encode_image(image, device, num_images_per_prompt) + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) # decoder text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( From 8dbe8cce5233424a0b6f584ab68f261e8824f037 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Dec 2022 20:59:59 +0100 Subject: [PATCH 3/7] up --- 1 | 502 -------------------------------------------------------------- 1 file changed, 502 deletions(-) delete mode 100644 1 diff --git a/1 b/1 deleted file mode 100644 index 31edf7091714..000000000000 --- a/1 +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Union, Tuple - -import torch -from torch.nn import functional as F - -from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel -from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from diffusers.schedulers import UnCLIPScheduler -from transformers import CLIPTextModelWithProjection, CLIPTokenizer -from transformers.models.clip.modeling_clip import CLIPTextModelOutput - -from ...utils import is_accelerate_available, logging -from .text_proj import UnCLIPTextProjModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class UnCLIPPipeline(DiffusionPipeline): - """ - Pipeline for text-to-image generation using unCLIP - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - prior ([`PriorTransformer`]): - The canonincal unCLIP prior to approximate the image embedding from the text embedding. - text_proj ([`UnCLIPTextProjModel`]): - Utility class to prepare and combine the embeddings before they are passed to the decoder. - decoder ([`UNet2DConditionModel`]): - The decoder to invert the image embedding into an image. - super_res_first ([`UNet2DModel`]): - Super resolution unet. Used in all but the last step of the super resolution diffusion process. - super_res_last ([`UNet2DModel`]): - Super resolution unet. Used in the last step of the super resolution diffusion process. - prior_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the prior denoising process. Just a modified DDPMScheduler. - decoder_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. - super_res_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. - - """ - - prior: PriorTransformer - decoder: UNet2DConditionModel - text_proj: UnCLIPTextProjModel - text_encoder: CLIPTextModelWithProjection - tokenizer: CLIPTokenizer - super_res_first: UNet2DModel - super_res_last: UNet2DModel - - prior_scheduler: UnCLIPScheduler - decoder_scheduler: UnCLIPScheduler - super_res_scheduler: UnCLIPScheduler - - def __init__( - self, - prior: PriorTransformer, - decoder: UNet2DConditionModel, - text_encoder: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - text_proj: UnCLIPTextProjModel, - super_res_first: UNet2DModel, - super_res_last: UNet2DModel, - prior_scheduler: UnCLIPScheduler, - decoder_scheduler: UnCLIPScheduler, - super_res_scheduler: UnCLIPScheduler, - ): - super().__init__() - - self.register_modules( - prior=prior, - decoder=decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_proj=text_proj, - super_res_first=super_res_first, - super_res_last=super_res_last, - prior_scheduler=prior_scheduler, - decoder_scheduler=decoder_scheduler, - super_res_scheduler=super_res_scheduler, - ) - - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): - if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - latents = latents * scheduler.init_noise_sigma - return latents - - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if text_model_output is None: - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - - text_encoder_output = self.text_encoder(text_input_ids.to(device)) - - text_embeddings = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state - else: - text_embeddings, text_encoder_hidden_states = text_encoder_output[0], text_encoder_output[1] - - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_text_mask = uncond_input.attention_mask.bool().to(device) - uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - - uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) - - seq_len = uncond_text_encoder_hidden_states.shape[1] - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - # done duplicates - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) - - text_mask = torch.cat([uncond_text_mask, text_mask]) - - return text_embeddings, text_encoder_hidden_states, text_mask - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list - models = [ - self.decoder, - self.text_proj, - self.text_encoder, - self.super_res_first, - self.super_res_last, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): - return self.device - for module in self.decoder.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - prior_num_inference_steps: int = 25, - decoder_num_inference_steps: int = 25, - super_res_num_inference_steps: int = 7, - generator: Optional[torch.Generator] = None, - prior_latents: Optional[torch.FloatTensor] = None, - decoder_latents: Optional[torch.FloatTensor] = None, - super_res_latents: Optional[torch.FloatTensor] = None, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - prior_guidance_scale: float = 4.0, - decoder_guidance_scale: float = 8.0, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - prior_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the prior. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - decoder_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - super_res_num_inference_steps (`int`, *optional*, defaults to 7): - The number of denoising steps for super resolution. More denoising steps usually lead to a higher - quality image at the expense of slower inference. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): - Pre-generated noisy latents to be used as inputs for the prior. - decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - prior_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - decoder_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - device = self._execution_device - - batch_size = batch_size * num_images_per_prompt - - do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - - text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance - ) - - # prior - - self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) - prior_timesteps_tensor = self.prior_scheduler.timesteps - - embedding_dim = self.prior.config.embedding_dim - - if prior_latents is None: - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - text_embeddings.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) - - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents - - predicted_image_embedding = self.prior( - latent_model_input, - timestep=t, - proj_embedding=text_embeddings, - encoder_hidden_states=text_encoder_hidden_states, - attention_mask=text_mask, - ).predicted_image_embedding - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - if i + 1 == prior_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = prior_timesteps_tensor[i + 1] - - prior_latents = self.prior_scheduler.step( - predicted_image_embedding, - timestep=t, - sample=prior_latents, - generator=generator, - prev_timestep=prev_timestep, - ).prev_sample - - prior_latents = self.prior.post_process_latents(prior_latents) - - image_embeddings = prior_latents - - # done prior - - # decoder - - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( - image_embeddings=image_embeddings, - text_embeddings=text_embeddings, - text_encoder_hidden_states=text_encoder_hidden_states, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) - - self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) - decoder_timesteps_tensor = self.decoder_scheduler.timesteps - - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size - - if decoder_latents is None: - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) - - for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents - - noise_pred = self.decoder( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=text_encoder_hidden_states, - class_labels=additive_clip_time_embeddings, - attention_mask=decoder_text_mask, - ).sample - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) - noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) - noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - if i + 1 == decoder_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = decoder_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - decoder_latents = self.decoder_scheduler.step( - noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - decoder_latents = decoder_latents.clamp(-1, 1) - - image_small = decoder_latents - - # done decoder - - # super res - - self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) - super_res_timesteps_tensor = self.super_res_scheduler.timesteps - - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size - - if super_res_latents is None: - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) - - interpolate_antialias = {} - if "antialias" in inspect.signature(F.interpolate).parameters: - interpolate_antialias["antialias"] = True - - image_upscaled = F.interpolate( - image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias - ) - - for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): - # no classifier free guidance - - if i == super_res_timesteps_tensor.shape[0] - 1: - unet = self.super_res_last - else: - unet = self.super_res_first - - latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) - - noise_pred = unet( - sample=latent_model_input, - timestep=t, - ).sample - - if i + 1 == super_res_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = super_res_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - super_res_latents = self.super_res_scheduler.step( - noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - image = super_res_latents - - # done super res - - # post processing - - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) From 6b029b4eed69c32b1662031e57165254c7addbc5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Dec 2022 22:51:33 +0100 Subject: [PATCH 4/7] make unclip for text work --- .../pipelines/unclip/pipeline_unclip.py | 74 ++++++------- tests/pipelines/unclip/test_unclip.py | 104 ++++++++++++++++++ .../unclip/test_unclip_image_variation.py | 5 +- 3 files changed, 141 insertions(+), 42 deletions(-) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 1da5b227bed1..097cc41def03 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -140,28 +140,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) - text_embeddings = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state - else: - batch_size = text_model_output[0].shape[0] - text_embeddings, text_encoder_hidden_states = text_encoder_output[0], text_encoder_output[1] - text_mask = text_attention_mask + else: + batch_size = text_model_output[0].shape[0] + text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size - max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", - max_length=max_length, + max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) @@ -329,15 +328,14 @@ def __call__( embedding_dim = self.prior.config.embedding_dim - if prior_latents is None: - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - text_embeddings.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance @@ -394,15 +392,14 @@ def __call__( height = self.decoder.sample_size width = self.decoder.sample_size - if decoder_latents is None: - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance @@ -448,15 +445,14 @@ def __call__( height = self.super_res_first.sample_size width = self.super_res_first.sample_size - if super_res_latents is None: - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index fb0cb75ea703..2ffc46e8d974 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -248,6 +248,110 @@ def test_unclip(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_unclip_custom_text(self): + device = torch.device("cpu") + + class DummyScheduler: + init_noise_sigma = 1 + + prior = self.dummy_prior + decoder = self.dummy_decoder + text_proj = self.dummy_text_proj + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + super_res_first = self.dummy_super_res_first + super_res_last = self.dummy_super_res_last + + prior_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample_range=5.0, + ) + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + pipe = UnCLIPPipeline( + prior=prior, + decoder=decoder, + text_proj=text_proj, + text_encoder=text_encoder, + tokenizer=tokenizer, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe = pipe.to(device) + + generator = torch.Generator(device=device).manual_seed(0) + dtype = prior.dtype + batch_size = 1 + + shape = (batch_size, prior.config.embedding_dim) + prior_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size) + decoder_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + + shape = (batch_size, super_res_first.in_channels // 2, super_res_first.sample_size, super_res_first.sample_size) + super_res_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + + pipe.set_progress_bar_config(disable=None) + + prompt = "this is a prompt example" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe( + [prompt], + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + prior_latents=prior_latents, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + output_type="np", + ) + image = output.images + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ) + text_model_output = text_encoder(text_inputs.input_ids) + text_attention_mask = text_inputs.attention_mask + + generator = torch.Generator(device=device).manual_seed(0) + image_from_text = pipe( + [prompt], + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + prior_latents=prior_latents, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + text_model_output=text_model_output, + text_attention_mask=text_attention_mask, + output_type="np", + )[0] + + # make sure passing text embeddings manually is identical + assert np.abs(image - image_from_text).max() < 1e-4 + @slow @require_torch_gpu diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index 9b790e2059e4..791b46d9d41a 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -421,11 +421,10 @@ def test_unclip_image_variation_karlo(self): "/unclip/karlo_v1_alpha_cat_variation_fp16.npy" ) - pipeline = UnCLIPImageVariationPipeline.from_pretrained( - "fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16 - ) + pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers") pipeline = pipeline.to(torch_device) pipeline.set_progress_bar_config(disable=None) + pipeline.enable_sequential_cpu_offload() generator = torch.Generator(device=torch_device).manual_seed(0) output = pipeline( From a1904886f52d7f401bcbc18df36dcedcfa1c3c40 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Dec 2022 23:19:58 +0100 Subject: [PATCH 5/7] finish allowing to pass embeddings --- .../pipelines/unclip/pipeline_unclip.py | 24 +++++++-- .../unclip/pipeline_unclip_image_variation.py | 20 +++++--- tests/pipelines/unclip/test_unclip.py | 22 ++++++--- .../unclip/test_unclip_image_variation.py | 49 +++++++++++++++++++ 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 097cc41def03..aa70b6613d43 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch from torch.nn import functional as F @@ -118,7 +118,15 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): if text_model_output is None: batch_size = len(prompt) if isinstance(prompt, list) else 1 @@ -241,7 +249,7 @@ def _execution_device(self): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, prior_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25, @@ -262,7 +270,8 @@ def __call__( Args: prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. prior_num_inference_steps (`int`, *optional*, defaults to 25): @@ -295,6 +304,13 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index d656ea815262..cdb9ccc16e6f 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -259,7 +259,7 @@ def _execution_device(self): @torch.no_grad() def __call__( self, - image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, num_images_per_prompt: int = 1, decoder_num_inference_steps: int = 25, super_res_num_inference_steps: int = 7, @@ -279,7 +279,7 @@ def __call__( The image or images to guide the image generation. If you provide a tensor, it needs to comply with the configuration of [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor`. + `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. decoder_num_inference_steps (`int`, *optional*, defaults to 25): @@ -301,18 +301,24 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can the be left to `None`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. """ - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, list): - batch_size = len(image) + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] else: - batch_size = image.shape[0] + batch_size = image_embeddings.shape[0] prompt = [""] * batch_size diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index 2ffc46e8d974..670082c20c24 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -248,7 +248,7 @@ def test_unclip(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - def test_unclip_custom_text(self): + def test_unclip_passed_text_embed(self): device = torch.device("cpu") class DummyScheduler: @@ -300,12 +300,23 @@ class DummyScheduler: batch_size = 1 shape = (batch_size, prior.config.embedding_dim) - prior_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + prior_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size) - decoder_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + decoder_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) - shape = (batch_size, super_res_first.in_channels // 2, super_res_first.sample_size, super_res_first.sample_size) - super_res_latents = pipe.prepare_latents(shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()) + shape = ( + batch_size, + super_res_first.in_channels // 2, + super_res_first.sample_size, + super_res_first.sample_size, + ) + super_res_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) pipe.set_progress_bar_config(disable=None) @@ -336,7 +347,6 @@ class DummyScheduler: generator = torch.Generator(device=device).manual_seed(0) image_from_text = pipe( - [prompt], generator=generator, prior_num_inference_steps=2, decoder_num_inference_steps=2, diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index 791b46d9d41a..86448e0d9375 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -402,6 +402,55 @@ def test_unclip_image_variation_input_num_images_per_prompt(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_unclip_passed_image_embed(self): + device = torch.device("cpu") + seed = 0 + + class DummyScheduler: + init_noise_sigma = 1 + + pipe = self.get_pipeline(device) + + generator = torch.Generator(device=device).manual_seed(0) + dtype = pipe.decoder.dtype + batch_size = 1 + + shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size) + decoder_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + shape = ( + batch_size, + pipe.super_res_first.in_channels // 2, + pipe.super_res_first.sample_size, + pipe.super_res_first.sample_size, + ) + super_res_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + + img_out_1 = pipe( + **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents + ).images + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + # Don't pass image, instead pass embedding + image = pipeline_inputs.pop("image") + image_embeddings = pipe.image_encoder(image).image_embeds + + img_out_2 = pipe( + **pipeline_inputs, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + image_embeddings=image_embeddings, + ).images + + # make sure passing text embeddings manually is identical + assert np.abs(img_out_1 - img_out_2).max() < 1e-4 + @slow @require_torch_gpu From 9e449c12151ca9ee50950c8f63c055f1515b49d3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 30 Dec 2022 11:39:46 +0100 Subject: [PATCH 6/7] correct more --- .../unclip/pipeline_unclip_image_variation.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index cdb9ccc16e6f..80b33aebafd7 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -126,7 +126,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): batch_size = len(prompt) if isinstance(prompt, list) else 1 @@ -139,15 +138,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) text_embeddings = text_encoder_output.text_embeds From b930e02916fcce6f88bd0e63ebf85502aad3be47 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 30 Dec 2022 11:42:41 +0100 Subject: [PATCH 7/7] make style --- src/diffusers/pipelines/unclip/pipeline_unclip.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index aa70b6613d43..24fae5e45b93 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -127,7 +127,6 @@ def _encode_prompt( text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, ): - if text_model_output is None: batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings