diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md
index 966a0e53b496..5d4cac6658cc 100644
--- a/docs/source/en/api/models/controlnet.md
+++ b/docs/source/en/api/models/controlnet.md
@@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## ControlNetOutput
-[[autodoc]] models.controlnet.ControlNetOutput
+[[autodoc]] models.controlnets.controlnet.ControlNetOutput
## FlaxControlNetModel
@@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## FlaxControlNetOutput
-[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
+[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md
index 59db64546fa2..78564d238eea 100644
--- a/docs/source/en/api/models/controlnet_sd3.md
+++ b/docs/source/en/api/models/controlnet_sd3.md
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
## SD3ControlNetOutput
-[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
+[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
index 46cabd863dfa..6b1826a1c92d 100644
--- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
+++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
@@ -229,11 +229,11 @@ def forward(
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index fb6d22084bd6..533aa5de1e87 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -487,7 +487,7 @@
else:
- _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
+ _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
@@ -914,7 +914,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
- from .models.controlnet_flax import FlaxControlNetModel
+ from .models.controlnets.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 518ab6df65c4..65e2418ac794 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -36,12 +36,16 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
- _import_structure["controlnet"] = ["ControlNetModel"]
- _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
- _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
- _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
- _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
- _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+ _import_structure["controlnets.controlnet_hunyuan"] = [
+ "HunyuanDiT2DControlNetModel",
+ "HunyuanDiT2DMultiControlNetModel",
+ ]
+ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -74,7 +78,7 @@
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
if is_flax_available():
- _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+ _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
@@ -94,12 +98,19 @@
ConsistencyDecoderVAE,
VQModel,
)
- from .controlnet import ControlNetModel
- from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
- from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
- from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
- from .controlnet_sparsectrl import SparseControlNetModel
- from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
+ from .controlnets import (
+ ControlNetModel,
+ ControlNetXSAdapter,
+ FluxControlNetModel,
+ FluxMultiControlNetModel,
+ HunyuanDiT2DControlNetModel,
+ HunyuanDiT2DMultiControlNetModel,
+ MultiControlNetModel,
+ SD3ControlNetModel,
+ SD3MultiControlNetModel,
+ SparseControlNetModel,
+ UNetControlNetXSModel,
+ )
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
@@ -137,7 +148,7 @@
)
if is_flax_available():
- from .controlnet_flax import FlaxControlNetModel
+ from .controlnets import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index d3ae96605077..174f2b9ada96 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -11,860 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders.single_file_model import FromOriginalModelMixin
-from ..utils import BaseOutput, logging
-from .attention_processor import (
- ADDED_KV_ATTENTION_PROCESSORS,
- CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
- AttnAddedKVProcessor,
- AttnProcessor,
+from ..utils import deprecate
+from .controlnets.controlnet import ( # noqa
+ BaseOutput,
+ ControlNetConditioningEmbedding,
+ ControlNetModel,
+ ControlNetOutput,
+ zero_module,
)
-from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- DownBlock2D,
- UNetMidBlock2D,
- UNetMidBlock2DCrossAttn,
- get_down_block,
-)
-from .unets.unet_2d_condition import UNet2DConditionModel
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-@dataclass
-class ControlNetOutput(BaseOutput):
- """
- The output of [`ControlNetModel`].
-
- Args:
- down_block_res_samples (`tuple[torch.Tensor]`):
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
- used to condition the original UNet's downsampling activations.
- mid_down_block_re_sample (`torch.Tensor`):
- The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
- Output can be used to condition the original UNet's middle block activation.
- """
-
- down_block_res_samples: Tuple[torch.Tensor]
- mid_block_res_sample: torch.Tensor
-
-
-class ControlNetConditioningEmbedding(nn.Module):
- """
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
- model) to encode image-space conditions ... into feature maps ..."
- """
-
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
-
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
- )
-
- def forward(self, conditioning):
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
-
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
-
- return embedding
-
-
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
- """
- A ControlNet model.
-
- Args:
- in_channels (`int`, defaults to 4):
- The number of channels in the input sample.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- freq_shift (`int`, defaults to 0):
- The frequency shift to apply to the time embedding.
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
- The tuple of downsample blocks to use.
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
- The tuple of output channels for each block.
- layers_per_block (`int`, defaults to 2):
- The number of layers per block.
- downsample_padding (`int`, defaults to 1):
- The padding to use for the downsampling convolution.
- mid_block_scale_factor (`float`, defaults to 1):
- The scale factor to use for the mid block.
- act_fn (`str`, defaults to "silu"):
- The activation function to use.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
- in post-processing.
- norm_eps (`float`, defaults to 1e-5):
- The epsilon to use for the normalization.
- cross_attention_dim (`int`, defaults to 1280):
- The dimension of the cross attention features.
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- encoder_hid_dim (`int`, *optional*, defaults to None):
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
- dimension to `cross_attention_dim`.
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
- The dimension of the attention heads.
- use_linear_projection (`bool`, defaults to `False`):
- class_embed_type (`str`, *optional*, defaults to `None`):
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
- addition_embed_type (`str`, *optional*, defaults to `None`):
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
- "text". "text" will use the `TextTimeEmbedding` layer.
- num_class_embeds (`int`, *optional*, defaults to 0):
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
- class conditioning with `class_embed_type` equal to `None`.
- upcast_attention (`bool`, defaults to `False`):
- resnet_time_scale_shift (`str`, defaults to `"default"`):
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
- `class_embed_type="projection"`.
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `conditioning_embedding` layer.
- global_pool_conditions (`bool`, defaults to `False`):
- TODO(Patrick) - unused parameter.
- addition_embed_type_num_heads (`int`, defaults to 64):
- The number of heads to use for the `TextTimeEmbedding` layer.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- in_channels: int = 4,
- conditioning_channels: int = 3,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str, ...] = (
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "DownBlock2D",
- ),
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: Optional[int] = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 1280,
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- encoder_hid_dim: Optional[int] = None,
- encoder_hid_dim_type: Optional[str] = None,
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
- use_linear_projection: bool = False,
- class_embed_type: Optional[str] = None,
- addition_embed_type: Optional[str] = None,
- addition_time_embed_dim: Optional[int] = None,
- num_class_embeds: Optional[int] = None,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- projection_class_embeddings_input_dim: Optional[int] = None,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- global_pool_conditions: bool = False,
- addition_embed_type_num_heads: int = 64,
- ):
- super().__init__()
-
- # If `num_attention_heads` is not defined (which is the case for most models)
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
- # which is why we correct for the naming here.
- num_attention_heads = num_attention_heads or attention_head_dim
-
- # Check inputs
- if len(block_out_channels) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
- )
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
-
- # input
- conv_in_kernel = 3
- conv_in_padding = (conv_in_kernel - 1) // 2
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
- )
-
- # time
- time_embed_dim = block_out_channels[0] * 4
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
- self.time_embedding = TimestepEmbedding(
- timestep_input_dim,
- time_embed_dim,
- act_fn=act_fn,
- )
-
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
- encoder_hid_dim_type = "text_proj"
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
-
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
- raise ValueError(
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
- )
-
- if encoder_hid_dim_type == "text_proj":
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
- elif encoder_hid_dim_type == "text_image_proj":
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
- self.encoder_hid_proj = TextImageProjection(
- text_embed_dim=encoder_hid_dim,
- image_embed_dim=cross_attention_dim,
- cross_attention_dim=cross_attention_dim,
- )
-
- elif encoder_hid_dim_type is not None:
- raise ValueError(
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
- )
- else:
- self.encoder_hid_proj = None
-
- # class embedding
- if class_embed_type is None and num_class_embeds is not None:
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
- elif class_embed_type == "timestep":
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
- elif class_embed_type == "identity":
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
- elif class_embed_type == "projection":
- if projection_class_embeddings_input_dim is None:
- raise ValueError(
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
- )
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
- # 2. it projects from an arbitrary input dimension.
- #
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
- else:
- self.class_embedding = None
-
- if addition_embed_type == "text":
- if encoder_hid_dim is not None:
- text_time_embedding_from_dim = encoder_hid_dim
- else:
- text_time_embedding_from_dim = cross_attention_dim
-
- self.add_embedding = TextTimeEmbedding(
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
- )
- elif addition_embed_type == "text_image":
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
- self.add_embedding = TextImageTimeEmbedding(
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
- )
- elif addition_embed_type == "text_time":
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
-
- elif addition_embed_type is not None:
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
-
- # control net conditioning embedding
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- self.down_blocks = nn.ModuleList([])
- self.controlnet_down_blocks = nn.ModuleList([])
-
- if isinstance(only_cross_attention, bool):
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- if isinstance(num_attention_heads, int):
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
- # down
- output_channel = block_out_channels[0]
-
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- down_block = get_down_block(
- down_block_type,
- num_layers=layers_per_block,
- transformer_layers_per_block=transformer_layers_per_block[i],
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- add_downsample=not is_final_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[i],
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
- downsample_padding=downsample_padding,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- self.down_blocks.append(down_block)
-
- for _ in range(layers_per_block):
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- if not is_final_block:
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- # mid
- mid_block_channel = block_out_channels[-1]
-
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_mid_block = controlnet_block
-
- if mid_block_type == "UNetMidBlock2DCrossAttn":
- self.mid_block = UNetMidBlock2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=mid_block_channel,
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- elif mid_block_type == "UNetMidBlock2D":
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- num_layers=0,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- add_attention=False,
- )
- else:
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- load_weights_from_unet: bool = True,
- conditioning_channels: int = 3,
- ):
- r"""
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
- where applicable.
- """
- transformer_layers_per_block = (
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
- )
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
- addition_time_embed_dim = (
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
- )
-
- controlnet = cls(
- encoder_hid_dim=encoder_hid_dim,
- encoder_hid_dim_type=encoder_hid_dim_type,
- addition_embed_type=addition_embed_type,
- addition_time_embed_dim=addition_time_embed_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=unet.config.in_channels,
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
- freq_shift=unet.config.freq_shift,
- down_block_types=unet.config.down_block_types,
- only_cross_attention=unet.config.only_cross_attention,
- block_out_channels=unet.config.block_out_channels,
- layers_per_block=unet.config.layers_per_block,
- downsample_padding=unet.config.downsample_padding,
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
- act_fn=unet.config.act_fn,
- norm_num_groups=unet.config.norm_num_groups,
- norm_eps=unet.config.norm_eps,
- cross_attention_dim=unet.config.cross_attention_dim,
- attention_head_dim=unet.config.attention_head_dim,
- num_attention_heads=unet.config.num_attention_heads,
- use_linear_projection=unet.config.use_linear_projection,
- class_embed_type=unet.config.class_embed_type,
- num_class_embeds=unet.config.num_class_embeds,
- upcast_attention=unet.config.upcast_attention,
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
- mid_block_type=unet.config.mid_block_type,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- if load_weights_from_unet:
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
-
- if controlnet.class_embedding:
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
-
- if hasattr(controlnet, "add_embedding"):
- controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
-
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
-
- return controlnet
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnAddedKVProcessor()
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnProcessor()
- else:
- raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
- )
-
- self.set_attn_processor(processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_sliceable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_sliceable_dims(module)
-
- num_sliceable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_sliceable_layers * [1]
-
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
- # Recursively walk through all the children.
- # Any children which exposes the set_attention_slice method
- # gets the message
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
- """
- The [`ControlNetModel`] forward method.
-
- Args:
- sample (`torch.Tensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- guess_mode (`bool`, defaults to `False`):
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
- returned where the first element is the sample tensor.
- """
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- emb = self.time_embedding(t_emb, timestep_cond)
- aug_emb = None
-
- if self.class_embedding is not None:
- if class_labels is None:
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
-
- if self.config.class_embed_type == "timestep":
- class_labels = self.time_proj(class_labels)
-
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
- emb = emb + class_emb
-
- if self.config.addition_embed_type is not None:
- if self.config.addition_embed_type == "text":
- aug_emb = self.add_embedding(encoder_hidden_states)
-
- elif self.config.addition_embed_type == "text_time":
- if "text_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
- )
- text_embeds = added_cond_kwargs.get("text_embeds")
- if "time_ids" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
- )
- time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = self.add_time_proj(time_ids.flatten())
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
- add_embeds = add_embeds.to(emb.dtype)
- aug_emb = self.add_embedding(add_embeds)
-
- emb = emb + aug_emb if aug_emb is not None else emb
-
- # 2. pre-process
- sample = self.conv_in(sample)
-
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
- sample = sample + controlnet_cond
-
- # 3. down
- down_block_res_samples = (sample,)
- for downsample_block in self.down_blocks:
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
-
- down_block_res_samples += res_samples
-
- # 4. mid
- if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample = self.mid_block(sample, emb)
-
- # 5. Control net blocks
- controlnet_down_block_res_samples = ()
-
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
- down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
-
- down_block_res_samples = controlnet_down_block_res_samples
-
- mid_block_res_sample = self.controlnet_mid_block(sample)
- # 6. scaling
- if guess_mode and not self.config.global_pool_conditions:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- scales = scales * conditioning_scale
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- else:
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
- if self.config.global_pool_conditions:
- down_block_res_samples = [
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
- ]
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+class ControlNetOutput(ControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
+ deprecate("ControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- if not return_dict:
- return (down_block_res_samples, mid_block_res_sample)
- return ControlNetOutput(
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
- )
+class ControlNetModel(ControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
+ deprecate("ControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
-def zero_module(module):
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
+class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
+ deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py
index 961e30155a3d..9b256239d712 100644
--- a/src/diffusers/models/controlnet_flux.py
+++ b/src/diffusers/models/controlnet_flux.py
@@ -12,525 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-import torch
-import torch.nn as nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import PeftAdapterMixin
-from ..models.attention_processor import AttentionProcessor
-from ..models.modeling_utils import ModelMixin
-from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
-from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from .modeling_outputs import Transformer2DModelOutput
-from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+from ..utils import deprecate, logging
+from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class FluxControlNetOutput(BaseOutput):
- controlnet_block_samples: Tuple[torch.Tensor]
- controlnet_single_block_samples: Tuple[torch.Tensor]
-
-
-class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- patch_size: int = 1,
- in_channels: int = 64,
- num_layers: int = 19,
- num_single_layers: int = 38,
- attention_head_dim: int = 128,
- num_attention_heads: int = 24,
- joint_attention_dim: int = 4096,
- pooled_projection_dim: int = 768,
- guidance_embeds: bool = False,
- axes_dims_rope: List[int] = [16, 56, 56],
- num_mode: int = None,
- conditioning_embedding_channels: int = None,
- ):
- super().__init__()
- self.out_channels = in_channels
- self.inner_dim = num_attention_heads * attention_head_dim
-
- self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
- text_time_guidance_cls = (
- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
- )
- self.time_text_embed = text_time_guidance_cls(
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
- )
-
- self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
- self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
-
- self.transformer_blocks = nn.ModuleList(
- [
- FluxTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- )
- for i in range(num_layers)
- ]
- )
-
- self.single_transformer_blocks = nn.ModuleList(
- [
- FluxSingleTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- )
- for i in range(num_single_layers)
- ]
- )
-
- # controlnet_blocks
- self.controlnet_blocks = nn.ModuleList([])
- for _ in range(len(self.transformer_blocks)):
- self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
- self.controlnet_single_blocks = nn.ModuleList([])
- for _ in range(len(self.single_transformer_blocks)):
- self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
- self.union = num_mode is not None
- if self.union:
- self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
-
- if conditioning_embedding_channels is not None:
- self.input_hint_block = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
- )
- self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
- else:
- self.input_hint_block = None
- self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
-
- self.gradient_checkpointing = False
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self):
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
- @classmethod
- def from_transformer(
- cls,
- transformer,
- num_layers: int = 4,
- num_single_layers: int = 10,
- attention_head_dim: int = 128,
- num_attention_heads: int = 24,
- load_weights_from_transformer=True,
- ):
- config = transformer.config
- config["num_layers"] = num_layers
- config["num_single_layers"] = num_single_layers
- config["attention_head_dim"] = attention_head_dim
- config["num_attention_heads"] = num_attention_heads
-
- controlnet = cls(**config)
-
- if load_weights_from_transformer:
- controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
- controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
- controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
- controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
- controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
- controlnet.single_transformer_blocks.load_state_dict(
- transformer.single_transformer_blocks.state_dict(), strict=False
- )
-
- controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
-
- return controlnet
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- controlnet_mode: torch.Tensor = None,
- conditioning_scale: float = 1.0,
- encoder_hidden_states: torch.Tensor = None,
- pooled_projections: torch.Tensor = None,
- timestep: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
- """
- The [`FluxTransformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
- Input `hidden_states`.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- controlnet_mode (`torch.Tensor`):
- The mode tensor of shape `(batch_size, 1)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
- Used to indicate denoising step.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
- A list of tensors that if specified are added to the residuals of transformer blocks.
- joint_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- if joint_attention_kwargs is not None:
- joint_attention_kwargs = joint_attention_kwargs.copy()
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
- else:
- lora_scale = 1.0
-
- if USE_PEFT_BACKEND:
- # weight the lora layers by setting `lora_scale` for each PEFT layer
- scale_lora_layers(self, lora_scale)
- else:
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
- logger.warning(
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
- )
- hidden_states = self.x_embedder(hidden_states)
-
- if self.input_hint_block is not None:
- controlnet_cond = self.input_hint_block(controlnet_cond)
- batch_size, channels, height_pw, width_pw = controlnet_cond.shape
- height = height_pw // self.config.patch_size
- width = width_pw // self.config.patch_size
- controlnet_cond = controlnet_cond.reshape(
- batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
- )
- controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
- controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
- # add
- hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
-
- timestep = timestep.to(hidden_states.dtype) * 1000
- if guidance is not None:
- guidance = guidance.to(hidden_states.dtype) * 1000
- else:
- guidance = None
- temb = (
- self.time_text_embed(timestep, pooled_projections)
- if guidance is None
- else self.time_text_embed(timestep, guidance, pooled_projections)
- )
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
- if self.union:
- # union mode
- if controlnet_mode is None:
- raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
- # union mode emb
- controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
- encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
- txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
-
- if txt_ids.ndim == 3:
- logger.warning(
- "Passing `txt_ids` 3d torch.Tensor is deprecated."
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
- )
- txt_ids = txt_ids[0]
- if img_ids.ndim == 3:
- logger.warning(
- "Passing `img_ids` 3d torch.Tensor is deprecated."
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
- )
- img_ids = img_ids[0]
-
- ids = torch.cat((txt_ids, img_ids), dim=0)
- image_rotary_emb = self.pos_embed(ids)
-
- block_samples = ()
- for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- temb,
- image_rotary_emb,
- **ckpt_kwargs,
- )
-
- else:
- encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
- block_samples = block_samples + (hidden_states,)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- single_block_samples = ()
- for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- temb,
- image_rotary_emb,
- **ckpt_kwargs,
- )
-
- else:
- hidden_states = block(
- hidden_states=hidden_states,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
-
- # controlnet block
- controlnet_block_samples = ()
- for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
- block_sample = controlnet_block(block_sample)
- controlnet_block_samples = controlnet_block_samples + (block_sample,)
-
- controlnet_single_block_samples = ()
- for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
- single_block_sample = controlnet_block(single_block_sample)
- controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
-
- # scaling
- controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
- controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
-
- controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
- controlnet_single_block_samples = (
- None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
- )
-
- if USE_PEFT_BACKEND:
- # remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
-
- if not return_dict:
- return (controlnet_block_samples, controlnet_single_block_samples)
-
- return FluxControlNetOutput(
- controlnet_block_samples=controlnet_block_samples,
- controlnet_single_block_samples=controlnet_single_block_samples,
- )
-
-
-class FluxMultiControlNetModel(ModelMixin):
- r"""
- `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
-
- This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
- compatible with `FluxControlNetModel`.
-
- Args:
- controlnets (`List[FluxControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `FluxControlNetModel` as a list.
- """
-
- def __init__(self, controlnets):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: List[torch.tensor],
- controlnet_mode: List[torch.tensor],
- conditioning_scale: List[float],
- encoder_hidden_states: torch.Tensor = None,
- pooled_projections: torch.Tensor = None,
- timestep: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[FluxControlNetOutput, Tuple]:
- # ControlNet-Union with multiple conditions
- # only load one ControlNet for saving memories
- if len(self.nets) == 1 and self.nets[0].union:
- controlnet = self.nets[0]
-
- for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
- block_samples, single_block_samples = controlnet(
- hidden_states=hidden_states,
- controlnet_cond=image,
- controlnet_mode=mode[:, None],
- conditioning_scale=scale,
- timestep=timestep,
- guidance=guidance,
- pooled_projections=pooled_projections,
- encoder_hidden_states=encoder_hidden_states,
- txt_ids=txt_ids,
- img_ids=img_ids,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
-
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- control_single_block_samples = single_block_samples
- else:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples, block_samples)
- ]
+class FluxControlNetOutput(FluxControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
+ deprecate("FluxControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- control_single_block_samples = [
- control_single_block_sample + block_sample
- for control_single_block_sample, block_sample in zip(
- control_single_block_samples, single_block_samples
- )
- ]
- # Regular Multi-ControlNets
- # load all ControlNets into memories
- else:
- for i, (image, mode, scale, controlnet) in enumerate(
- zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
- ):
- block_samples, single_block_samples = controlnet(
- hidden_states=hidden_states,
- controlnet_cond=image,
- controlnet_mode=mode[:, None],
- conditioning_scale=scale,
- timestep=timestep,
- guidance=guidance,
- pooled_projections=pooled_projections,
- encoder_hidden_states=encoder_hidden_states,
- txt_ids=txt_ids,
- img_ids=img_ids,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
+class FluxControlNetModel(FluxControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
+ deprecate("FluxControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- control_single_block_samples = single_block_samples
- else:
- if block_samples is not None and control_block_samples is not None:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples, block_samples)
- ]
- if single_block_samples is not None and control_single_block_samples is not None:
- control_single_block_samples = [
- control_single_block_sample + block_sample
- for control_single_block_sample, block_sample in zip(
- control_single_block_samples, single_block_samples
- )
- ]
- return control_block_samples, control_single_block_samples
+class FluxMultiControlNetModel(FluxMultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
+ deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 43b52a645a0d..5e70559e9ac4 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -13,410 +13,29 @@
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
-from ..models.modeling_outputs import Transformer2DModelOutput
-from ..models.modeling_utils import ModelMixin
-from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, zero_module
-from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from ..utils import deprecate, logging
+from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class SD3ControlNetOutput(BaseOutput):
- controlnet_block_samples: Tuple[torch.Tensor]
-
-
-class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- sample_size: int = 128,
- patch_size: int = 2,
- in_channels: int = 16,
- num_layers: int = 18,
- attention_head_dim: int = 64,
- num_attention_heads: int = 18,
- joint_attention_dim: int = 4096,
- caption_projection_dim: int = 1152,
- pooled_projection_dim: int = 2048,
- out_channels: int = 16,
- pos_embed_max_size: int = 96,
- extra_conditioning_channels: int = 0,
- ):
- super().__init__()
- default_out_channels = in_channels
- self.out_channels = out_channels if out_channels is not None else default_out_channels
- self.inner_dim = num_attention_heads * attention_head_dim
-
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=self.inner_dim,
- pos_embed_max_size=pos_embed_max_size,
- )
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
- )
- self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
- # `attention_head_dim` is doubled to account for the mixing.
- # It needs to crafted when we get the actual checkpoints.
- self.transformer_blocks = nn.ModuleList(
- [
- JointTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
- context_pre_only=False,
- )
- for i in range(num_layers)
- ]
- )
-
- # controlnet_blocks
- self.controlnet_blocks = nn.ModuleList([])
- for _ in range(len(self.transformer_blocks)):
- controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_blocks.append(controlnet_block)
- pos_embed_input = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels + extra_conditioning_channels,
- embed_dim=self.inner_dim,
- pos_embed_type=None,
- )
- self.pos_embed_input = zero_module(pos_embed_input)
-
- self.gradient_checkpointing = False
-
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
- """
- Sets the attention processor to use [feed forward
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
-
- Parameters:
- chunk_size (`int`, *optional*):
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
- over each tensor of dim=`dim`.
- dim (`int`, *optional*, defaults to `0`):
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
- or dim=1 (sequence length).
- """
- if dim not in [0, 1]:
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
-
- # By default chunk size is 1
- chunk_size = chunk_size or 1
-
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
- if hasattr(module, "set_chunk_feed_forward"):
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
-
- for child in module.children():
- fn_recursive_feed_forward(child, chunk_size, dim)
-
- for module in self.children():
- fn_recursive_feed_forward(module, chunk_size, dim)
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedJointAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
- @classmethod
- def from_transformer(
- cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
- ):
- config = transformer.config
- config["num_layers"] = num_layers or config.num_layers
- config["extra_conditioning_channels"] = num_extra_conditioning_channels
- controlnet = cls(**config)
-
- if load_weights_from_transformer:
- controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
- controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
- controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
- controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
-
- controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
-
- return controlnet
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- encoder_hidden_states: torch.FloatTensor = None,
- pooled_projections: torch.FloatTensor = None,
- timestep: torch.LongTensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
- """
- The [`SD3Transformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
- Input `hidden_states`.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
- Used to indicate denoising step.
- joint_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- if joint_attention_kwargs is not None:
- joint_attention_kwargs = joint_attention_kwargs.copy()
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
- else:
- lora_scale = 1.0
-
- if USE_PEFT_BACKEND:
- # weight the lora layers by setting `lora_scale` for each PEFT layer
- scale_lora_layers(self, lora_scale)
- else:
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
- logger.warning(
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
- )
-
- hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
- temb = self.time_text_embed(timestep, pooled_projections)
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
- # add
- hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
-
- block_res_samples = ()
-
- for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- temb,
- **ckpt_kwargs,
- )
-
- else:
- encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
- )
-
- block_res_samples = block_res_samples + (hidden_states,)
-
- controlnet_block_res_samples = ()
- for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
- block_res_sample = controlnet_block(block_res_sample)
- controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-
- # 6. scaling
- controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
-
- if USE_PEFT_BACKEND:
- # remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
-
- if not return_dict:
- return (controlnet_block_res_samples,)
-
- return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
-
-
-class SD3MultiControlNetModel(ModelMixin):
- r"""
- `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
-
- This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
- compatible with `SD3ControlNetModel`.
-
- Args:
- controlnets (`List[SD3ControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `SD3ControlNetModel` as a list.
- """
+class SD3ControlNetOutput(SD3ControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
+ deprecate("SD3ControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- def __init__(self, controlnets):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: List[torch.tensor],
- conditioning_scale: List[float],
- pooled_projections: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- timestep: torch.LongTensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[SD3ControlNetOutput, Tuple]:
- for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
- block_samples = controlnet(
- hidden_states=hidden_states,
- timestep=timestep,
- encoder_hidden_states=encoder_hidden_states,
- pooled_projections=pooled_projections,
- controlnet_cond=image,
- conditioning_scale=scale,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
+class SD3ControlNetModel(SD3ControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
+ deprecate("SD3ControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- else:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
- ]
- control_block_samples = (tuple(control_block_samples),)
- return control_block_samples
+class SD3MultiControlNetModel(SD3MultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
+ deprecate("SD3MultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
index fa37e1f9e393..1ccbd385b9a6 100644
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnet_sparsectrl.py
@@ -12,777 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalModelMixin
-from ..utils import BaseOutput, logging
-from .attention_processor import (
- ADDED_KV_ATTENTION_PROCESSORS,
- CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
- AttnAddedKVProcessor,
- AttnProcessor,
+from ..utils import deprecate, logging
+from .controlnets.controlnet_sparsectrl import ( # noqa
+ SparseControlNetConditioningEmbedding,
+ SparseControlNetModel,
+ SparseControlNetOutput,
+ zero_module,
)
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
-from .unets.unet_2d_condition import UNet2DConditionModel
-from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class SparseControlNetOutput(BaseOutput):
- """
- The output of [`SparseControlNetModel`].
-
- Args:
- down_block_res_samples (`tuple[torch.Tensor]`):
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
- used to condition the original UNet's downsampling activations.
- mid_down_block_re_sample (`torch.Tensor`):
- The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
- Output can be used to condition the original UNet's middle block activation.
- """
-
- down_block_res_samples: Tuple[torch.Tensor]
- mid_block_res_sample: torch.Tensor
-
-
-class SparseControlNetConditioningEmbedding(nn.Module):
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
- )
-
- def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
-
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
- return embedding
-
-
-class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
- """
- A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
- Models](https://arxiv.org/abs/2311.16933).
-
- Args:
- in_channels (`int`, defaults to 4):
- The number of channels in the input sample.
- conditioning_channels (`int`, defaults to 4):
- The number of input channels in the controlnet conditional embedding module. If
- `concat_condition_embedding` is True, the value provided here is incremented by 1.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- freq_shift (`int`, defaults to 0):
- The frequency shift to apply to the time embedding.
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
- The tuple of downsample blocks to use.
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
- The tuple of output channels for each block.
- layers_per_block (`int`, defaults to 2):
- The number of layers per block.
- downsample_padding (`int`, defaults to 1):
- The padding to use for the downsampling convolution.
- mid_block_scale_factor (`float`, defaults to 1):
- The scale factor to use for the mid block.
- act_fn (`str`, defaults to "silu"):
- The activation function to use.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
- in post-processing.
- norm_eps (`float`, defaults to 1e-5):
- The epsilon to use for the normalization.
- cross_attention_dim (`int`, defaults to 1280):
- The dimension of the cross attention features.
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer layers to use in each layer in the middle block.
- attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
- The dimension of the attention heads.
- num_attention_heads (`int` or `Tuple[int]`, *optional*):
- The number of heads to use for multi-head attention.
- use_linear_projection (`bool`, defaults to `False`):
- upcast_attention (`bool`, defaults to `False`):
- resnet_time_scale_shift (`str`, defaults to `"default"`):
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
- conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `conditioning_embedding` layer.
- global_pool_conditions (`bool`, defaults to `False`):
- TODO(Patrick) - unused parameter
- controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
- motion_max_seq_length (`int`, defaults to `32`):
- The maximum sequence length to use in the motion module.
- motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
- The number of heads to use in each attention layer of the motion module.
- concat_conditioning_mask (`bool`, defaults to `True`):
- use_simplified_condition_embedding (`bool`, defaults to `True`):
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- in_channels: int = 4,
- conditioning_channels: int = 4,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str, ...] = (
- "CrossAttnDownBlockMotion",
- "CrossAttnDownBlockMotion",
- "CrossAttnDownBlockMotion",
- "DownBlockMotion",
- ),
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: Optional[int] = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 768,
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
- temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
- use_linear_projection: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- global_pool_conditions: bool = False,
- controlnet_conditioning_channel_order: str = "rgb",
- motion_max_seq_length: int = 32,
- motion_num_attention_heads: int = 8,
- concat_conditioning_mask: bool = True,
- use_simplified_condition_embedding: bool = True,
- ):
- super().__init__()
- self.use_simplified_condition_embedding = use_simplified_condition_embedding
-
- # If `num_attention_heads` is not defined (which is the case for most models)
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
- # which is why we correct for the naming here.
- num_attention_heads = num_attention_heads or attention_head_dim
-
- # Check inputs
- if len(block_out_channels) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
- )
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
- if isinstance(temporal_transformer_layers_per_block, int):
- temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
-
- # input
- conv_in_kernel = 3
- conv_in_padding = (conv_in_kernel - 1) // 2
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
- )
-
- if concat_conditioning_mask:
- conditioning_channels = conditioning_channels + 1
-
- self.concat_conditioning_mask = concat_conditioning_mask
-
- # control net conditioning embedding
- if use_simplified_condition_embedding:
- self.controlnet_cond_embedding = zero_module(
- nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
- )
- else:
- self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- # time
- time_embed_dim = block_out_channels[0] * 4
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
-
- self.time_embedding = TimestepEmbedding(
- timestep_input_dim,
- time_embed_dim,
- act_fn=act_fn,
- )
-
- self.down_blocks = nn.ModuleList([])
- self.controlnet_down_blocks = nn.ModuleList([])
-
- if isinstance(cross_attention_dim, int):
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
-
- if isinstance(only_cross_attention, bool):
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- if isinstance(num_attention_heads, int):
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
- if isinstance(motion_num_attention_heads, int):
- motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
-
- # down
- output_channel = block_out_channels[0]
-
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- if down_block_type == "CrossAttnDownBlockMotion":
- down_block = CrossAttnDownBlockMotion(
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=layers_per_block,
- transformer_layers_per_block=transformer_layers_per_block[i],
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- num_attention_heads=num_attention_heads[i],
- cross_attention_dim=cross_attention_dim[i],
- add_downsample=not is_final_block,
- dual_cross_attention=False,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- temporal_num_attention_heads=motion_num_attention_heads[i],
- temporal_max_seq_length=motion_max_seq_length,
- temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
- temporal_double_self_attention=False,
- )
- elif down_block_type == "DownBlockMotion":
- down_block = DownBlockMotion(
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=layers_per_block,
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- add_downsample=not is_final_block,
- temporal_num_attention_heads=motion_num_attention_heads[i],
- temporal_max_seq_length=motion_max_seq_length,
- temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
- temporal_double_self_attention=False,
- )
- else:
- raise ValueError(
- "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
- )
-
- self.down_blocks.append(down_block)
-
- for _ in range(layers_per_block):
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- if not is_final_block:
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- # mid
- mid_block_channels = block_out_channels[-1]
-
- controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_mid_block = controlnet_block
-
- if transformer_layers_per_mid_block is None:
- transformer_layers_per_mid_block = (
- transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
- )
-
- self.mid_block = UNetMidBlock2DCrossAttn(
- in_channels=mid_block_channels,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=1,
- transformer_layers_per_block=transformer_layers_per_mid_block,
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- num_attention_heads=num_attention_heads[-1],
- output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim[-1],
- dual_cross_attention=False,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- attention_type="default",
- )
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- load_weights_from_unet: bool = True,
- conditioning_channels: int = 3,
- ) -> "SparseControlNetModel":
- r"""
- Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
- copied where applicable.
- """
- transformer_layers_per_block = (
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
- )
- down_block_types = unet.config.down_block_types
-
- for i in range(len(down_block_types)):
- if "CrossAttn" in down_block_types[i]:
- down_block_types[i] = "CrossAttnDownBlockMotion"
- elif "Down" in down_block_types[i]:
- down_block_types[i] = "DownBlockMotion"
- else:
- raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
-
- controlnet = cls(
- in_channels=unet.config.in_channels,
- conditioning_channels=conditioning_channels,
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
- freq_shift=unet.config.freq_shift,
- down_block_types=unet.config.down_block_types,
- only_cross_attention=unet.config.only_cross_attention,
- block_out_channels=unet.config.block_out_channels,
- layers_per_block=unet.config.layers_per_block,
- downsample_padding=unet.config.downsample_padding,
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
- act_fn=unet.config.act_fn,
- norm_num_groups=unet.config.norm_num_groups,
- norm_eps=unet.config.norm_eps,
- cross_attention_dim=unet.config.cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- attention_head_dim=unet.config.attention_head_dim,
- num_attention_heads=unet.config.num_attention_heads,
- use_linear_projection=unet.config.use_linear_projection,
- upcast_attention=unet.config.upcast_attention,
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- )
-
- if load_weights_from_unet:
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-
- return controlnet
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnAddedKVProcessor()
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnProcessor()
- else:
- raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
- )
-
- self.set_attn_processor(processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_sliceable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_sliceable_dims(module)
-
- num_sliceable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_sliceable_layers * [1]
-
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
- # Recursively walk through all the children.
- # Any children which exposes the set_attention_slice method
- # gets the message
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- conditioning_mask: Optional[torch.Tensor] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
- """
- The [`SparseControlNetModel`] forward method.
-
- Args:
- sample (`torch.Tensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- guess_mode (`bool`, defaults to `False`):
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
- Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
- returned where the first element is the sample tensor.
- """
- sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
- sample = torch.zeros_like(sample)
-
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- emb = self.time_embedding(t_emb, timestep_cond)
- emb = emb.repeat_interleave(sample_num_frames, dim=0)
-
- # 2. pre-process
- batch_size, channels, num_frames, height, width = sample.shape
-
- sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- sample = self.conv_in(sample)
-
- batch_frames, channels, height, width = sample.shape
- sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
-
- if self.concat_conditioning_mask:
- controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
-
- batch_size, channels, num_frames, height, width = controlnet_cond.shape
- controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
- batch_size * num_frames, channels, height, width
- )
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
- batch_frames, channels, height, width = controlnet_cond.shape
- controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
-
- sample = sample + controlnet_cond
-
- batch_size, num_frames, channels, height, width = sample.shape
- sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
-
- # 3. down
- down_block_res_samples = (sample,)
- for downsample_block in self.down_blocks:
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
-
- down_block_res_samples += res_samples
-
- # 4. mid
- if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample = self.mid_block(sample, emb)
-
- # 5. Control net blocks
- controlnet_down_block_res_samples = ()
-
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
- down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
-
- down_block_res_samples = controlnet_down_block_res_samples
- mid_block_res_sample = self.controlnet_mid_block(sample)
-
- # 6. scaling
- if guess_mode and not self.config.global_pool_conditions:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- scales = scales * conditioning_scale
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- else:
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
-
- if self.config.global_pool_conditions:
- down_block_res_samples = [
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
- ]
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+class SparseControlNetOutput(SparseControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
+ deprecate("SparseControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- if not return_dict:
- return (down_block_res_samples, mid_block_res_sample)
- return SparseControlNetOutput(
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
- )
+class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
+ deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
-# Copied from diffusers.models.controlnet.zero_module
-def zero_module(module: nn.Module) -> nn.Module:
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
+class SparseControlNetModel(SparseControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
+ deprecate("SparseControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py
new file mode 100644
index 000000000000..3e4b3561e839
--- /dev/null
+++ b/src/diffusers/models/controlnets/__init__.py
@@ -0,0 +1,22 @@
+from ...utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+ from .controlnet import ControlNetModel, ControlNetOutput
+ from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
+ from .controlnet_hunyuan import (
+ HunyuanControlNetOutput,
+ HunyuanDiT2DControlNetModel,
+ HunyuanDiT2DMultiControlNetModel,
+ )
+ from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
+ from .controlnet_sparsectrl import (
+ SparseControlNetConditioningEmbedding,
+ SparseControlNetModel,
+ SparseControlNetOutput,
+ )
+ from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
+ from .multicontrolnet import MultiControlNetModel
+
+if is_flax_available():
+ from .controlnet_flax import FlaxControlNetModel
diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
new file mode 100644
index 000000000000..bd00f6dd1906
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -0,0 +1,872 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import BaseOutput, logging
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from ..unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A ControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ mid_block_type=unet.config.mid_block_type,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ if hasattr(controlnet, "add_embedding"):
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`ControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample = sample + controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py
similarity index 98%
rename from src/diffusers/models/controlnet_flax.py
rename to src/diffusers/models/controlnets/controlnet_flax.py
index 0540850a9e61..ab8d9b5f8cbb 100644
--- a/src/diffusers/models/controlnet_flax.py
+++ b/src/diffusers/models/controlnets/controlnet_flax.py
@@ -19,11 +19,11 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
-from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
-from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
-from .modeling_flax_utils import FlaxModelMixin
-from .unets.unet_2d_blocks_flax import (
+from ...configuration_utils import ConfigMixin, flax_register_to_config
+from ...utils import BaseOutput
+from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from ..modeling_flax_utils import FlaxModelMixin
+from ..unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
new file mode 100644
index 000000000000..e6a3eceed9b4
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -0,0 +1,536 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...models.attention_processor import AttentionProcessor
+from ...models.modeling_utils import ModelMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FluxControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+ controlnet_single_block_samples: Tuple[torch.Tensor]
+
+
+class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ num_mode: int = None,
+ conditioning_embedding_channels: int = None,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_single_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.controlnet_single_blocks = nn.ModuleList([])
+ for _ in range(len(self.single_transformer_blocks)):
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.union = num_mode is not None
+ if self.union:
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
+
+ if conditioning_embedding_channels is not None:
+ self.input_hint_block = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+ )
+ self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+ else:
+ self.input_hint_block = None
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self):
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_transformer(
+ cls,
+ transformer,
+ num_layers: int = 4,
+ num_single_layers: int = 10,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ load_weights_from_transformer=True,
+ ):
+ config = transformer.config
+ config["num_layers"] = num_layers
+ config["num_single_layers"] = num_single_layers
+ config["attention_head_dim"] = attention_head_dim
+ config["num_attention_heads"] = num_attention_heads
+
+ controlnet = cls(**config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+ controlnet.single_transformer_blocks.load_state_dict(
+ transformer.single_transformer_blocks.state_dict(), strict=False
+ )
+
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ controlnet_mode: torch.Tensor = None,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ controlnet_mode (`torch.Tensor`):
+ The mode tensor of shape `(batch_size, 1)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ if self.input_hint_block is not None:
+ controlnet_cond = self.input_hint_block(controlnet_cond)
+ batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+ height = height_pw // self.config.patch_size
+ width = width_pw // self.config.patch_size
+ controlnet_cond = controlnet_cond.reshape(
+ batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+ )
+ controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+ controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
+ # add
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if self.union:
+ # union mode
+ if controlnet_mode is None:
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+ # union mode emb
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ block_samples = ()
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ block_samples = block_samples + (hidden_states,)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ single_block_samples = ()
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+
+ # controlnet block
+ controlnet_block_samples = ()
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+ block_sample = controlnet_block(block_sample)
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+ controlnet_single_block_samples = ()
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
+ single_block_sample = controlnet_block(single_block_sample)
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
+
+ # scaling
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
+
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+ controlnet_single_block_samples = (
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_samples, controlnet_single_block_samples)
+
+ return FluxControlNetOutput(
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ )
+
+
+class FluxMultiControlNetModel(ModelMixin):
+ r"""
+ `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
+
+ This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
+ compatible with `FluxControlNetModel`.
+
+ Args:
+ controlnets (`List[FluxControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `FluxControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ controlnet_mode: List[torch.tensor],
+ conditioning_scale: List[float],
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[FluxControlNetOutput, Tuple]:
+ # ControlNet-Union with multiple conditions
+ # only load one ControlNet for saving memories
+ if len(self.nets) == 1 and self.nets[0].union:
+ controlnet = self.nets[0]
+
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ # Regular Multi-ControlNets
+ # load all ControlNets into memories
+ else:
+ for i, (image, mode, scale, controlnet) in enumerate(
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
+ ):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ if single_block_samples is not None and control_single_block_samples is not None:
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ return control_block_samples, control_single_block_samples
diff --git a/src/diffusers/models/controlnet_hunyuan.py b/src/diffusers/models/controlnets/controlnet_hunyuan.py
similarity index 98%
rename from src/diffusers/models/controlnet_hunyuan.py
rename to src/diffusers/models/controlnets/controlnet_hunyuan.py
index 4277d81d1cb9..f2aa34d2d056 100644
--- a/src/diffusers/models/controlnet_hunyuan.py
+++ b/src/diffusers/models/controlnets/controlnet_hunyuan.py
@@ -17,17 +17,17 @@
import torch
from torch import nn
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import logging
-from .attention_processor import AttentionProcessor
-from .controlnet import BaseOutput, Tuple, zero_module
-from .embeddings import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention_processor import AttentionProcessor
+from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
-from .modeling_utils import ModelMixin
-from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from ..modeling_utils import ModelMixin
+from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from .controlnet import BaseOutput, Tuple, zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
new file mode 100644
index 000000000000..911d65e03d88
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -0,0 +1,422 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import JointTransformerBlock
+from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from .controlnet import BaseOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class SD3ControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ extra_conditioning_channels: int = 0,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+ # `attention_head_dim` is doubled to account for the mixing.
+ # It needs to crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ context_pre_only=False,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ pos_embed_input = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels + extra_conditioning_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_type=None,
+ )
+ self.pos_embed_input = zero_module(pos_embed_input)
+
+ self.gradient_checkpointing = False
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_transformer(
+ cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
+ ):
+ config = transformer.config
+ config["num_layers"] = num_layers or config.num_layers
+ config["extra_conditioning_channels"] = num_extra_conditioning_channels
+ controlnet = cls(**config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+
+ controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`SD3Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+ block_res_samples = ()
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ block_res_samples = block_res_samples + (hidden_states,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ # 6. scaling
+ controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_res_samples,)
+
+ return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class SD3MultiControlNetModel(ModelMixin):
+ r"""
+ `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+ This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+ compatible with `SD3ControlNetModel`.
+
+ Args:
+ controlnets (`List[SD3ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `SD3ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ pooled_projections: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[SD3ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ pooled_projections=pooled_projections,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+ ]
+ control_block_samples = (tuple(control_block_samples),)
+
+ return control_block_samples
diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
new file mode 100644
index 000000000000..fd599c10b2d7
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -0,0 +1,788 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import BaseOutput, logging
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
+from ..unets.unet_2d_condition import UNet2DConditionModel
+from ..unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class SparseControlNetOutput(BaseOutput):
+ """
+ The output of [`SparseControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class SparseControlNetConditioningEmbedding(nn.Module):
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+ return embedding
+
+
+class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
+ Models](https://arxiv.org/abs/2311.16933).
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ conditioning_channels (`int`, defaults to 4):
+ The number of input channels in the controlnet conditional embedding module. If
+ `concat_condition_embedding` is True, the value provided here is incremented by 1.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer layers to use in each layer in the middle block.
+ attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
+ The dimension of the attention heads.
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
+ The number of heads to use for multi-head attention.
+ use_linear_projection (`bool`, defaults to `False`):
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter
+ controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
+ motion_max_seq_length (`int`, defaults to `32`):
+ The maximum sequence length to use in the motion module.
+ motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
+ The number of heads to use in each attention layer of the motion module.
+ concat_conditioning_mask (`bool`, defaults to `True`):
+ use_simplified_condition_embedding (`bool`, defaults to `True`):
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 4,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockMotion",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 768,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
+ temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ controlnet_conditioning_channel_order: str = "rgb",
+ motion_max_seq_length: int = 32,
+ motion_num_attention_heads: int = 8,
+ concat_conditioning_mask: bool = True,
+ use_simplified_condition_embedding: bool = True,
+ ):
+ super().__init__()
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if isinstance(temporal_transformer_layers_per_block, int):
+ temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ if concat_conditioning_mask:
+ conditioning_channels = conditioning_channels + 1
+
+ self.concat_conditioning_mask = concat_conditioning_mask
+
+ # control net conditioning embedding
+ if use_simplified_condition_embedding:
+ self.controlnet_cond_embedding = zero_module(
+ nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+ )
+ else:
+ self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(motion_num_attention_heads, int):
+ motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "CrossAttnDownBlockMotion":
+ down_block = CrossAttnDownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ num_attention_heads=num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ dual_cross_attention=False,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ temporal_double_self_attention=False,
+ )
+ elif down_block_type == "DownBlockMotion":
+ down_block = DownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=layers_per_block,
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ add_downsample=not is_final_block,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ temporal_double_self_attention=False,
+ )
+ else:
+ raise ValueError(
+ "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
+ )
+
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channels = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ if transformer_layers_per_mid_block is None:
+ transformer_layers_per_mid_block = (
+ transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
+ )
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=mid_block_channels,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=1,
+ transformer_layers_per_block=transformer_layers_per_mid_block,
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ num_attention_heads=num_attention_heads[-1],
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=False,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type="default",
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ) -> "SparseControlNetModel":
+ r"""
+ Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
+ copied where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ down_block_types = unet.config.down_block_types
+
+ for i in range(len(down_block_types)):
+ if "CrossAttn" in down_block_types[i]:
+ down_block_types[i] = "CrossAttnDownBlockMotion"
+ elif "Down" in down_block_types[i]:
+ down_block_types[i] = "DownBlockMotion"
+ else:
+ raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
+
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ conditioning_mask: Optional[torch.Tensor] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`SparseControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
+ sample = torch.zeros_like(sample)
+
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(sample_num_frames, dim=0)
+
+ # 2. pre-process
+ batch_size, channels, num_frames, height, width = sample.shape
+
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+ sample = self.conv_in(sample)
+
+ batch_frames, channels, height, width = sample.shape
+ sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
+
+ if self.concat_conditioning_mask:
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
+
+ batch_size, channels, num_frames, height, width = controlnet_cond.shape
+ controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_frames, channels, height, width
+ )
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ batch_frames, channels, height, width = controlnet_cond.shape
+ controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
+
+ sample = sample + controlnet_cond
+
+ batch_size, num_frames, channels, height, width = sample.shape
+ sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return SparseControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+# Copied from diffusers.models.controlnets.controlnet.zero_module
+def zero_module(module: nn.Module) -> nn.Module:
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
similarity index 99%
rename from src/diffusers/models/controlnet_xs.py
rename to src/diffusers/models/controlnets/controlnet_xs.py
index f676a70f060a..06e0eda3c3b0 100644
--- a/src/diffusers/models/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -19,10 +19,10 @@
import torch.utils.checkpoint
from torch import Tensor, nn
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, is_torch_version, logging
-from ..utils.torch_utils import apply_freeu
-from .attention_processor import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput, is_torch_version, logging
+from ...utils.torch_utils import apply_freeu
+from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
@@ -31,10 +31,9 @@
AttnProcessor,
FusedAttnProcessor2_0,
)
-from .controlnet import ControlNetConditioningEmbedding
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import (
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
Downsample2D,
@@ -43,7 +42,8 @@
UNetMidBlock2DCrossAttn,
Upsample2D,
)
-from .unets.unet_2d_condition import UNet2DConditionModel
+from ..unets.unet_2d_condition import UNet2DConditionModel
+from .controlnet import ControlNetConditioningEmbedding
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1062,7 +1062,8 @@ def forward(
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
+ tuple.
apply_control (`bool`, defaults to `True`):
If `False`, the input is run only through the base model.
diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py
new file mode 100644
index 000000000000..46c3d1681cc1
--- /dev/null
+++ b/src/diffusers/models/controlnets/multicontrolnet.py
@@ -0,0 +1,183 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetModel(ModelMixin):
+ r"""
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+ compatible with `ControlNetModel`.
+
+ Args:
+ controlnets (`List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ down_samples, mid_sample = controlnet(
+ sample=sample,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ class_labels=class_labels,
+ timestep_cond=timestep_cond,
+ attention_mask=attention_mask,
+ added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guess_mode=guess_mode,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
+ return down_block_res_samples, mid_block_res_sample
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ variant (`str`, *optional*):
+ If specified, weights are saved in the format pytorch_model..bin.
+ """
+ for idx, controlnet in enumerate(self.nets):
+ suffix = "" if idx == 0 else f"_{idx}"
+ controlnet.save_pretrained(
+ save_directory + suffix,
+ is_main_process=is_main_process,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_path (`os.PathLike`):
+ A path to a *directory* containing model weights saved using
+ [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+ `./my_model_directory/controlnet`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+ ignored when using `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+ """
+ idx = 0
+ controlnets = []
+
+ # load controlnet and append to list until no controlnet directory exists anymore
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+ model_path_to_load = pretrained_model_path
+ while os.path.isdir(model_path_to_load):
+ controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+ controlnets.append(controlnet)
+
+ idx += 1
+ model_path_to_load = pretrained_model_path + f"_{idx}"
+
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+ if len(controlnets) == 0:
+ raise ValueError(
+ f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+ )
+
+ return cls(controlnets)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index 8b037cdc34fb..6dde7d6686ee 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -24,7 +24,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
-from ...models.controlnet_sparsectrl import SparseControlNetModel
+from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py
index e3c5ec6eed03..33790c10e064 100644
--- a/src/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,183 +1,12 @@
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from ...models.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
-from ...utils import logging
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
+from ...utils import deprecate, logging
logger = logging.get_logger(__name__)
-class MultiControlNetModel(ModelMixin):
- r"""
- Multiple `ControlNetModel` wrapper class for Multi-ControlNet
-
- This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
- compatible with `ControlNetModel`.
-
- Args:
- controlnets (`List[ControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `ControlNetModel` as a list.
- """
-
- def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: List[torch.tensor],
- conditioning_scale: List[float],
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple]:
- for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
- down_samples, mid_sample = controlnet(
- sample=sample,
- timestep=timestep,
- encoder_hidden_states=encoder_hidden_states,
- controlnet_cond=image,
- conditioning_scale=scale,
- class_labels=class_labels,
- timestep_cond=timestep_cond,
- attention_mask=attention_mask,
- added_cond_kwargs=added_cond_kwargs,
- cross_attention_kwargs=cross_attention_kwargs,
- guess_mode=guess_mode,
- return_dict=return_dict,
- )
-
- # merge samples
- if i == 0:
- down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
- else:
- down_block_res_samples = [
- samples_prev + samples_curr
- for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
- ]
- mid_block_res_sample += mid_sample
-
- return down_block_res_samples, mid_block_res_sample
-
- def save_pretrained(
- self,
- save_directory: Union[str, os.PathLike],
- is_main_process: bool = True,
- save_function: Callable = None,
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- ):
- """
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
- `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to which to save. Will be created if it doesn't exist.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful when in distributed training like
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
- the main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
- need to replace `torch.save` by another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
- variant (`str`, *optional*):
- If specified, weights are saved in the format pytorch_model..bin.
- """
- for idx, controlnet in enumerate(self.nets):
- suffix = "" if idx == 0 else f"_{idx}"
- controlnet.save_pretrained(
- save_directory + suffix,
- is_main_process=is_main_process,
- save_function=save_function,
- safe_serialization=safe_serialization,
- variant=variant,
- )
-
- @classmethod
- def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
- r"""
- Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
-
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
- the model, you should first set it back in training mode with `model.train()`.
-
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
- task.
-
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
- weights are discarded.
-
- Parameters:
- pretrained_model_path (`os.PathLike`):
- A path to a *directory* containing model weights saved using
- [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
- `./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
- output_loading_info(`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
- A map that specifies where each submodule should go. It doesn't need to be refined to each
- parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
- same device.
-
- To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
- more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
- max_memory (`Dict`, *optional*):
- A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
- GPU and the available CPU RAM if unset.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
- also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
- model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
- setting this argument to `True` will raise an error.
- variant (`str`, *optional*):
- If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
- ignored when using `from_flax`.
- use_safetensors (`bool`, *optional*, defaults to `None`):
- If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
- `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
- `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
- """
- idx = 0
- controlnets = []
-
- # load controlnet and append to list until no controlnet directory exists anymore
- # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
- # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
- model_path_to_load = pretrained_model_path
- while os.path.isdir(model_path_to_load):
- controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
- controlnets.append(controlnet)
-
- idx += 1
- model_path_to_load = pretrained_model_path + f"_{idx}"
-
- logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
-
- if len(controlnets) == 0:
- raise ValueError(
- f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
- )
-
- return cls(controlnets)
+class MultiControlNetModel(MultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
+ deprecate("MultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 9f674d2d7897..a589821c1f98 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -26,7 +26,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index f362c8f3d0c1..437bb9f2f182 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -26,7 +26,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 9f33e26013d5..771150b085d5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -27,7 +27,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 810c970ab715..04582b71d780 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -13,7 +13,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 1f5f83561f1c..947e97e272f8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -14,7 +14,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 295a94c1d2e4..12f31aec678b 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -31,7 +31,7 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
-from diffusers.models.controlnet_xs import UNetControlNetXSModel
+from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel