diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cfbdac08a3fb..40ee6c7a4960 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -165,6 +165,8 @@ title: Self-Attention Guidance - local: api/pipelines/stable_diffusion/panorama title: MultiDiffusion Panorama + - local: api/pipelines/stable_diffusion/controlnet + title: Text-to-Image Generation with ControlNet Conditioning title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index 6db2eba32220..dc425e98628c 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -64,6 +64,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## PriorTransformerOutput [[autodoc]] models.prior_transformer.PriorTransformerOutput +## ControlNetOutput +[[autodoc]] models.controlnet.ControlNetOutput + +## ControlNetModel +[[autodoc]] ControlNetModel + ## FlaxModelMixin [[autodoc]] FlaxModelMixin diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index bd24a170957b..411f56ce5871 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -46,6 +46,7 @@ available a colab notebook to directly try them out. |---|---|:---:|:---:| | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | - | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation | +| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing) | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx new file mode 100644 index 000000000000..f636d79ba50c --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -0,0 +1,160 @@ + + +# Text-to-Image Generation with ControlNet Conditioning + +## Overview + +[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala. + +Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details. + +The abstract of the paper is the following: + +*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* + +This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️ . + +Resources: + +* [Paper](https://arxiv.org/abs/2302.05543) +* [Original Code](https://github.com/lllyasviel/ControlNet) + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing) | + +## Usage example + +In the following we give a simple example of how to use a *ControlNet* checkpoint with Diffusers for inference. +The inference pipeline is the same for all pipelines: + +* 1. Take an image and run it through a pre-conditioning processor. +* 2. Run the pre-processed image through the [`StableDiffusionControlNetPipeline`]. + +Let's have a look at a simple example using the [Canny Edge ControlNet](https://huggingface.co/lllyasviel/sd-controlnet-canny). + +```python +from diffusers import StableDiffusionControlNetPipeline +from diffusers.utils import load_image + +# Let's load the popular vermeer image +image = load_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" +) +``` + +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png) + +Next, we process the image to get the canny image. This is step *1.* - running the pre-conditioning processor. The pre-conditioning processor is different for every ControlNet. Please see the model cards of the [official checkpoints](#controlnet-with-stable-diffusion-1.5) for more information about other models. + +First, we need to install opencv: + +``` +pip install opencv-contrib-python +``` + +Then we can retrieve the canny edges of the image. + +```python +import cv2 +from PIL import Image +import numpy as np + +image = np.array(image) + +low_threshold = 100 +high_threshold = 200 + +image = cv2.Canny(image, low_threshold, high_threshold) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +canny_image = Image.fromarray(image) +``` + +Let's take a look at the processed image. + +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png) + +Now, we load the official [Stable Diffusion 1.5 Model](runwayml/stable-diffusion-v1-5) as well as the ControlNet for canny edges. + +```py +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel +import torch + +controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) +pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 +) +``` + +To speed-up things and reduce memory, let's enable model offloading and use the fast [`UniPCMultistepScheduler`]. + +```py +from diffusers import UniPCMultistepScheduler + +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + +# this command loads the individual model components on GPU on-demand. +pipe.enable_model_cpu_offload() +``` + +Finally, we can run the pipeline: + +```py +generator = torch.manual_seed(0) + +out_image = pipe( + "disco dancer with colorful lights", num_inference_steps=20, generator=generator, image=canny_image +).images[0] +``` + +This should take only around 3-4 seconds on GPU (depending on hardware). The output image then looks as follows: + +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_disco_dancing.png) + + +**Note**: To see how to run all other ControlNet checkpoints, please have a look at [ControlNet with Stable Diffusion 1.5](#controlnet-with-stable-diffusion-1.5) + + + +## Available checkpoints + +ControlNet requires a *control image* in addition to the text-to-image *prompt*. +Each pretrained model is trained using a different conditioning method that requires different images for conditioning the generated outputs. For example, Canny edge conditioning requires the control image to be the output of a Canny filter, while depth conditioning requires the control image to be a depth map. See the overview and image examples below to know more. + +All checkpoints can be found under the authors' namespace [lllyasviel](https://huggingface.co/lllyasviel). + +### ControlNet with Stable Diffusion 1.5 + +| Model Name | Control Image Overview| Control Image Example | Generated Image Example | +|---|---|---|---| +|[lllyasviel/sd-controlnet-canny](https://huggingface.co/lllyasviel/sd-controlnet-canny)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| +|[lllyasviel/sd-controlnet-depth](https://huggingface.co/lllyasviel/sd-controlnet-depth)
*Trained with Midas depth estimation* |A grayscale image with black representing deep areas and white representing shallow areas.||| +|[lllyasviel/sd-controlnet-hed](https://huggingface.co/lllyasviel/sd-controlnet-hed)
*Trained with HED edge detection (soft edge)* |A monochrome image with white soft edges on a black background.|| | +|[lllyasviel/sd-controlnet-mlsd](https://huggingface.co/lllyasviel/sd-controlnet-mlsd)
*Trained with M-LSD line detection* |A monochrome image composed only of white straight lines on a black background.||| +|[lllyasviel/sd-controlnet-normal](https://huggingface.co/lllyasviel/sd-controlnet-normal)
*Trained with normal map* |A [normal mapped](https://en.wikipedia.org/wiki/Normal_mapping) image.||| +|[lllyasviel/sd-controlnet_openpose](https://huggingface.co/lllyasviel/sd-controlnet_openpose)
*Trained with OpenPose bone image* |A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.||| +|[lllyasviel/sd-controlnet_scribble](https://huggingface.co/lllyasviel/sd-controlnet_scribble)
*Trained with human scribbles* |A hand-drawn monochrome image with white outlines on a black background.|| | +|[lllyasviel/sd-controlnet_seg](https://huggingface.co/lllyasviel/sd-controlnet_seg)
*Trained with semantic segmentation* |An [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/)'s segmentation protocol image.|| | + +[[autodoc]] StableDiffusionControlNetPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 675278bf8924..5471d9235c7e 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -36,6 +36,7 @@ available a colab notebook to directly try them out. |---|---|:---:|:---:| | [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | | [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb) +| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing) | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | diff --git a/docs/source/en/using-diffusers/controlling_generation.mdx b/docs/source/en/using-diffusers/controlling_generation.mdx index bce101a715f7..74da955abcf2 100644 --- a/docs/source/en/using-diffusers/controlling_generation.mdx +++ b/docs/source/en/using-diffusers/controlling_generation.mdx @@ -35,6 +35,7 @@ Unless otherwise mentioned, these are techniques that work with existing models 7. [MultiDiffusion Panorama](#multidiffusion-panorama) 8. [DreamBooth](#dreambooth) 9. [Textual Inversion](#textual-inversion) +10. [ControlNet](#controlnet) ## Instruct Pix2Pix @@ -146,3 +147,14 @@ See [here](../training/dreambooth) for more information on how to use it. [Textual Inversion](../training/text_inversion) fine-tunes a model to teach it about a new concept. I.e. a few pictures of a style of artwork can be used to generate images in that style. See [here](../training/text_inversion) for more information on how to use it. + +## ControlNet + +[Paper](https://arxiv.org/abs/2302.05543) + +[ControlNet](../api/pipelines/stable_diffusion/controlnet) is an auxiliary network which adds an extra condition. +There are 8 canonical pre-trained ControlNets trained on different conditionings such as edge detection, scribbles, +depth maps, and semantic segmentations. + +See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it. + diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 492fd1a8fa29..15afbccb900e 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -120,6 +120,9 @@ help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", required=False, ) + parser.add_argument( + "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." + ) args = parser.parse_args() pipe = load_pipeline_from_original_stable_diffusion_ckpt( @@ -137,5 +140,11 @@ stable_unclip=args.stable_unclip, stable_unclip_prior=args.stable_unclip_prior, clip_stats_path=args.clip_stats_path, + controlnet=args.controlnet, ) - pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) + + if args.controlnet: + # only save the controlnet model + pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) + else: + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3767a0b410f6..59b0a5ec206c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -34,6 +34,7 @@ else: from .models import ( AutoencoderKL, + ControlNetModel, ModelMixin, PriorTransformer, Transformer2DModel, @@ -113,6 +114,7 @@ PaintByExamplePipeline, SemanticStableDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 127a85f263d6..e0b2cddd4bf9 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -17,6 +17,7 @@ if is_torch_available(): from .autoencoder_kl import AutoencoderKL + from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py new file mode 100644 index 000000000000..5b55da5af372 --- /dev/null +++ b/src/diffusers/models/controlnet.py @@ -0,0 +1,506 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .cross_attention import AttnProcessor +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + ): + super().__init__() + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttnProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d6e2e4a234cc..a5f997e9b400 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -142,7 +142,7 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", # fourier, positional + time_embedding_type: str = "positional", timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, @@ -492,6 +492,8 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -589,6 +591,17 @@ def forward( down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -599,6 +612,9 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, ) + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -625,6 +641,7 @@ def forward( sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) + # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2888eb6ced0d..f5b7026fdc85 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -48,6 +48,7 @@ from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 0059ba91954b..c20394845e86 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -45,6 +45,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 72acbe5462a9..a46121f0c9cd 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -34,6 +34,7 @@ from diffusers import ( AutoencoderKL, + ControlNetModel, DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, @@ -44,6 +45,7 @@ LMSDiscreteScheduler, PNDMScheduler, PriorTransformer, + StableDiffusionControlNetPipeline, StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, @@ -224,11 +226,15 @@ def conv_attn_to_linear(checkpoint): checkpoint[key] = checkpoint[key][:, :, 0] -def create_unet_diffusers_config(original_config, image_size: int): +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. """ - unet_params = original_config.model.params.unet_config.params + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + unet_params = original_config.model.params.unet_config.params + vae_params = original_config.model.params.first_stage_config.params.ddconfig block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] @@ -272,9 +278,7 @@ def create_unet_diffusers_config(original_config, image_size: int): config = dict( sample_size=image_size // vae_scale_factor, in_channels=unet_params.in_channels, - out_channels=unet_params.out_channels, down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), block_out_channels=tuple(block_out_channels), layers_per_block=unet_params.num_res_blocks, cross_attention_dim=unet_params.context_dim, @@ -284,6 +288,10 @@ def create_unet_diffusers_config(original_config, image_size: int): projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, ) + if not controlnet: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + return config @@ -331,7 +339,7 @@ def create_ldm_bert_config(original_config): return config -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): """ Takes a state dict and a config, and returns a converted checkpoint. """ @@ -340,7 +348,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False unet_state_dict = {} keys = list(checkpoint.keys()) - unet_key = "model.diffusion_model." + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: print(f"Checkpoint {path} has both EMA and non-EMA weights.") @@ -384,10 +396,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) @@ -512,6 +525,48 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False new_checkpoint[new_path] = unet_state_dict[old_path] + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + return new_checkpoint @@ -912,6 +967,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( stable_unclip: Optional[str] = None, stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, ) -> StableDiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1093,6 +1149,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt( model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + if controlnet is None: + controlnet = "control_stage_config" in original_config.model.params + + if controlnet and model_type != "FrozenCLIPEmbedder": + raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'") + if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") @@ -1180,15 +1242,41 @@ def load_pipeline_from_original_stable_diffusion_ckpt( tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) + + if controlnet: + # Convert the ControlNetModel model. + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + controlnet_model = ControlNetModel(**ctrlnet_config) + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + ) + controlnet_model.load_state_dict(converted_ctrl_checkpoint) + + pipe = StableDiffusionControlNetPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet_model, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py new file mode 100644 index 000000000000..00f35ee4d021 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -0,0 +1,820 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`]): + Provides additional conditioning to the unet during the denoising process + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [ + np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image + ] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + if isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[3] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[2] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can + also be accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + image = self.prepare_image( + image, + width, + height, + batch_size * num_images_per_prompt, + num_images_per_prompt, + device, + self.controlnet.dtype, + ) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + return_dict=False, + ) + + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4adf9eed0e29..261135d33318 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -228,7 +228,7 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", # fourier, positional + time_embedding_type: str = "positional", timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, @@ -582,6 +582,8 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -679,6 +681,17 @@ def forward( down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -689,6 +702,9 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, ) + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -715,6 +731,7 @@ def forward( sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) + # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 67e537a63596..c731a1f1ddf3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 24731a00cabe..1b0f812ad16c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -107,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py new file mode 100644 index 000000000000..406cbc9ad089 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -0,0 +1,398 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import require_torch_gpu + +from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +@slow +@require_torch_gpu +class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_depth(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_hed(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "oil painting of handsome old man, masterpiece" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/man_hed.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (704, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/man_hed_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_mlsd(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "room" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/room_mlsd.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (704, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/room_mlsd_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_normal(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-normal") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "cute toy" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/cute_toy_normal.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/cute_toy_normal_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_openpose(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Chef in the kitchen" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/chef_pose_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_scribble(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(5) + prompt = "bag" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bag_scribble.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (640, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bag_scribble_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_seg(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(5) + prompt = "house" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/house_seg.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/house_seg_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-3 + + def test_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() + + prompt = "house" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/house_seg.png" + ) + + _ = pipe( + prompt, + image, + num_inference_steps=2, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 7 GB is allocated + assert mem_bytes < 4 * 10**9 diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 0756340c1f3b..bdcf0edd43fd 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -240,6 +240,7 @@ def _test_inference_batch_single_identical( test_max_difference=None, test_mean_pixel_difference=None, relax_max_difference=False, + expected_max_diff=1e-4, additional_params_copy_to_batched_inputs=["num_inference_steps"], ): if test_max_difference is None: @@ -308,7 +309,7 @@ def _test_inference_batch_single_identical( max_diff = np.median(diff[-5:]) else: max_diff = np.abs(output_batch[0][0] - output[0][0]).max() - assert max_diff < 1e-4 + assert max_diff < expected_max_diff if test_mean_pixel_difference: assert_mean_pixel_difference(output_batch[0][0], output[0][0]) @@ -449,7 +450,7 @@ def test_to_device(self): def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass() - def _test_attention_slicing_forward_pass(self, test_max_difference=True): + def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3): if not self.test_attention_slicing: return @@ -471,7 +472,7 @@ def _test_attention_slicing_forward_pass(self, test_max_difference=True): if test_max_difference: max_diff = np.abs(output_with_slicing - output_without_slicing).max() - self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results") + self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0]) @@ -505,7 +506,7 @@ def test_cpu_offload_forward_pass(self): def test_xformers_attention_forwardGenerator_pass(self): self._test_xformers_attention_forwardGenerator_pass() - def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True): + def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4): if not self.test_xformers_attention: return @@ -523,7 +524,7 @@ def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=Tru if test_max_difference: max_diff = np.abs(output_with_offload - output_without_offload).max() - self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") + self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results") assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])