From fb33e6ce9417fda5ba882008de00995d2eee50a7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 15:55:17 +0530 Subject: [PATCH 01/21] move vae flax module. --- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/__init__.py | 6 +++--- src/diffusers/models/autoencoders/__init__.py | 5 +++++ src/diffusers/models/{ => autoencoders}/vae_flax.py | 6 +++--- 4 files changed, 13 insertions(+), 8 deletions(-) rename src/diffusers/models/{ => autoencoders}/vae_flax.py (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6f80cab0f357..93dd519d577a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -422,10 +422,10 @@ else: + _import_structure["models.autoencoders.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ @@ -791,10 +791,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: + from .models.autoencoders.vae_flax import FlaxAutoencoderKL from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f3fda596aa71..9e2a998b2c04 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -59,10 +59,10 @@ _import_structure["unets.uvit_2d"] = ["UVit2DModel"] if is_flax_available(): + _import_structure["autoencoders.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] - _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] - if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): @@ -107,9 +107,9 @@ ) if is_flax_available(): + from .autoencoders import FlaxAutoencoderKL from .controlnet_flax import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL else: import sys diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 5c47748d62e0..7d6eac31bc28 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,6 +1,11 @@ +from ...utils import is_flax_available from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel + + +if is_flax_available(): + from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/autoencoders/vae_flax.py similarity index 99% rename from src/diffusers/models/vae_flax.py rename to src/diffusers/models/autoencoders/vae_flax.py index 5027f4230e3b..abee7e55ef36 100644 --- a/src/diffusers/models/vae_flax.py +++ 
b/src/diffusers/models/autoencoders/vae_flax.py @@ -24,9 +24,9 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .modeling_flax_utils import FlaxModelMixin +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..modeling_flax_utils import FlaxModelMixin @flax.struct.dataclass From 8cbe71b5393111577f39aee674a08e075330b5f8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 15:59:33 +0530 Subject: [PATCH 02/21] controlnet module. --- .../promptdiffusioncontrolnet.py | 2 +- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 10 +++++----- .../models/{ => controlnets}/controlnet.py | 16 ++++++++-------- .../{ => controlnets}/controlnet_flax.py | 10 +++++----- .../{ => controlnets}/controlnet_hunyuan.py | 14 +++++++------- .../models/{ => controlnets}/controlnet_sd3.py | 16 ++++++++-------- .../models/{ => controlnets}/controlnet_xs.py | 18 +++++++++--------- .../pipelines/controlnet/multicontrolnet.py | 2 +- .../pipeline_stable_diffusion_3_controlnet.py | 2 +- tests/pipelines/test_pipelines_common.py | 2 +- 11 files changed, 47 insertions(+), 47 deletions(-) rename src/diffusers/models/{ => controlnets}/controlnet.py (98%) rename src/diffusers/models/{ => controlnets}/controlnet_flax.py (98%) rename src/diffusers/models/{ => controlnets}/controlnet_hunyuan.py (98%) rename src/diffusers/models/{ => controlnets}/controlnet_sd3.py (96%) rename src/diffusers/models/{ => controlnets}/controlnet_xs.py (99%) diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py index 46cabd863dfa..0e1eb9dcde49 100644 --- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py +++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py @@ -16,7 +16,7 @@ import torch from diffusers.configuration_utils import register_to_config -from diffusers.models.controlnet import ( +from diffusers.models.controlnets.controlnet import ( ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 93dd519d577a..063f1f09f7ef 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -792,7 +792,7 @@ from .utils.dummy_flax_objects import * # noqa F403 else: from .models.autoencoders.vae_flax import FlaxAutoencoderKL - from .models.controlnet_flax import FlaxControlNetModel + from .models.controlnets.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel from .pipelines import FlaxDiffusionPipeline diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 9e2a998b2c04..b45496edadff 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -75,10 +75,10 @@ ConsistencyDecoderVAE, VQModel, ) - from .controlnet import ControlNetModel - from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel - from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel - from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel + from .controlnets.controlnet import ControlNetModel + from .controlnets.controlnet_hunyuan import HunyuanDiT2DControlNetModel, 
HunyuanDiT2DMultiControlNetModel + from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel + from .controlnets.controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( @@ -108,7 +108,7 @@ if is_flax_available(): from .autoencoders import FlaxAutoencoderKL - from .controlnet_flax import FlaxControlNetModel + from .controlnets.controlnet_flax import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel else: diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnets/controlnet.py similarity index 98% rename from src/diffusers/models/controlnet.py rename to src/diffusers/models/controlnets/controlnet.py index 8fb49d7b547f..9f390affc430 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -18,26 +18,26 @@ from torch import nn from torch.nn import functional as F -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders.single_file_model import FromOriginalModelMixin -from ..utils import BaseOutput, logging -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import BaseOutput, logging +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import ( +from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import ( CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) -from .unets.unet_2d_condition import UNet2DConditionModel +from ..unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py similarity index 98% rename from src/diffusers/models/controlnet_flax.py rename to src/diffusers/models/controlnets/controlnet_flax.py index 0540850a9e61..ab8d9b5f8cbb 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnets/controlnet_flax.py @@ -19,11 +19,11 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .modeling_flax_utils import FlaxModelMixin -from .unets.unet_2d_blocks_flax import ( +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin +from ..unets.unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, diff --git a/src/diffusers/models/controlnet_hunyuan.py b/src/diffusers/models/controlnets/controlnet_hunyuan.py similarity index 98% rename from src/diffusers/models/controlnet_hunyuan.py rename to src/diffusers/models/controlnets/controlnet_hunyuan.py index 
c97cdaf913b3..228b3d476c08 100644 --- a/src/diffusers/models/controlnet_hunyuan.py +++ b/src/diffusers/models/controlnets/controlnet_hunyuan.py @@ -17,17 +17,17 @@ import torch from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging -from .attention_processor import AttentionProcessor -from .controlnet import BaseOutput, Tuple, zero_module -from .embeddings import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention_processor import AttentionProcessor +from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, PixArtAlphaTextProjection, ) -from .modeling_utils import ModelMixin -from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock +from ..modeling_utils import ModelMixin +from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock +from .controlnet import BaseOutput, Tuple, zero_module logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py similarity index 96% rename from src/diffusers/models/controlnet_sd3.py rename to src/diffusers/models/controlnets/controlnet_sd3.py index 86a6a9178514..78cddbdfbe5f 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -19,15 +19,15 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalModelMixin, PeftAdapterMixin -from ..models.attention import JointTransformerBlock -from ..models.attention_processor import Attention, AttentionProcessor -from ..models.modeling_outputs import Transformer2DModelOutput -from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..attention import JointTransformerBlock +from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin from .controlnet import BaseOutput, zero_module -from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py similarity index 99% rename from src/diffusers/models/controlnet_xs.py rename to src/diffusers/models/controlnets/controlnet_xs.py index 354acfebe0a2..2d53d4c24e7a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -19,10 +19,10 @@ import torch.utils.checkpoint from torch import Tensor, nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, is_torch_version, logging -from ..utils.torch_utils import apply_freeu -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention_processor import ( 
ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, @@ -30,10 +30,9 @@ AttnAddedKVProcessor, AttnProcessor, ) -from .controlnet import ControlNetConditioningEmbedding -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import ( +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, Downsample2D, @@ -42,7 +41,8 @@ UNetMidBlock2DCrossAttn, Upsample2D, ) -from .unets.unet_2d_condition import UNet2DConditionModel +from ..unets.unet_2d_condition import UNet2DConditionModel +from .controlnet import ControlNetConditioningEmbedding logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py index e3c5ec6eed03..46c3d1681cc1 100644 --- a/src/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -4,7 +4,7 @@ import torch from torch import nn -from ...models.controlnet import ControlNetModel, ControlNetOutput +from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput from ...models.modeling_utils import ModelMixin from ...utils import logging diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 7542f895ccf2..71dd4d3ca24f 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 8f2419db92e3..67eaed6aa5e0 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -32,7 +32,7 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor -from diffusers.models.controlnet_xs import UNetControlNetXSModel +from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel From 56c3c582fdec88e01f71322966252ae9e8f8ca69 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:23:59 +0530 Subject: [PATCH 03/21] prepare for PR. 
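
Introduce the `diffusers.models.controlnets` subpackage. Its `__init__.py`
re-exports the torch ControlNet variants behind `is_torch_available()` and
`FlaxControlNetModel` behind `is_flax_available()`; `models/__init__.py` now
routes both its lazy `_import_structure` entries and its `TYPE_CHECKING`
imports through that subpackage, and the entries in `dummy_flax_objects.py`
are re-sorted.

Minimal usage sketch of the new import surface (based only on the diff below,
not part of the patch itself; assumes torch, and flax where noted, are
installed):

    from diffusers.models.controlnets import ControlNetModel, SD3ControlNetModel
    from diffusers.utils import is_flax_available

    # Flax-only classes are simply not defined when flax is missing,
    # so gate any use of them explicitly:
    if is_flax_available():
        from diffusers.models.controlnets import FlaxControlNetModel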
--- src/diffusers/models/__init__.py | 26 +++++++++++++------- src/diffusers/models/controlnets/__init__.py | 15 +++++++++++ src/diffusers/utils/dummy_flax_objects.py | 8 +++--- 3 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 src/diffusers/models/controlnets/__init__.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b45496edadff..e540009324b9 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,10 +32,13 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] - _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] - _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] - _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet"] = ["ControlNetModel"] + _import_structure["controlnets.controlnet_hunyuan"] = [ + "HunyuanDiT2DControlNetModel", + "HunyuanDiT2DMultiControlNetModel", + ] + _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] + _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] @@ -75,10 +78,15 @@ ConsistencyDecoderVAE, VQModel, ) - from .controlnets.controlnet import ControlNetModel - from .controlnets.controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel - from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel - from .controlnets.controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel + from .controlnets import ( + ControlNetModel, + ControlNetXSAdapter, + HunyuanDiT2DControlNetModel, + HunyuanDiT2DMultiControlNetModel, + SD3ControlNetModel, + SD3MultiControlNetModel, + UNetControlNetXSModel, + ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( @@ -108,7 +116,7 @@ if is_flax_available(): from .autoencoders import FlaxAutoencoderKL - from .controlnets.controlnet_flax import FlaxControlNetModel + from .controlnets import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel else: diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py new file mode 100644 index 000000000000..75f6fc541c5e --- /dev/null +++ b/src/diffusers/models/controlnets/__init__.py @@ -0,0 +1,15 @@ +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .controlnet import ControlNetModel, ControlNetOutput + from .controlnet_hunyuan import ( + HunyuanControlNetOutput, + HunyuanDiT2DControlNetModel, + HunyuanDiT2DMultiControlNetModel, + ) + from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel + from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + +if is_flax_available(): + from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 5fa8dbc81931..ea3559236f28 100644 --- 
a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -2,7 +2,7 @@ from ..utils import DummyObject, requires_backends -class FlaxControlNetModel(metaclass=DummyObject): +class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxModelMixin(metaclass=DummyObject): +class FlaxControlNetModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -32,7 +32,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): +class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -47,7 +47,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxAutoencoderKL(metaclass=DummyObject): +class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): From 4c5588c25d3cd54c3862765767599fc8eab20a4c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:36:03 +0530 Subject: [PATCH 04/21] revert a commit --- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/__init__.py | 8 ++++---- src/diffusers/models/autoencoders/__init__.py | 5 ----- src/diffusers/models/{autoencoders => }/vae_flax.py | 6 +++--- 4 files changed, 9 insertions(+), 14 deletions(-) rename src/diffusers/models/{autoencoders => }/vae_flax.py (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 063f1f09f7ef..5e4cfb55e299 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -422,10 +422,10 @@ else: - _import_structure["models.autoencoders.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ @@ -791,10 +791,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: - from .models.autoencoders.vae_flax import FlaxAutoencoderKL from .models.controlnets.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e540009324b9..1e46ef66988f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -62,10 +62,10 @@ _import_structure["unets.uvit_2d"] = ["UVit2DModel"] if is_flax_available(): - _import_structure["autoencoders.vae_flax"] = ["FlaxAutoencoderKL"] - - _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["controlents.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): @@ -115,9 +115,9 @@ ) if is_flax_available(): - from .autoencoders import 
FlaxAutoencoderKL from .controlnets import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL else: import sys diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 7d6eac31bc28..5c47748d62e0 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,11 +1,6 @@ -from ...utils import is_flax_available from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel - - -if is_flax_available(): - from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/autoencoders/vae_flax.py b/src/diffusers/models/vae_flax.py similarity index 99% rename from src/diffusers/models/autoencoders/vae_flax.py rename to src/diffusers/models/vae_flax.py index abee7e55ef36..5027f4230e3b 100644 --- a/src/diffusers/models/autoencoders/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -24,9 +24,9 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ...configuration_utils import ConfigMixin, flax_register_to_config -from ...utils import BaseOutput -from ..modeling_flax_utils import FlaxModelMixin +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .modeling_flax_utils import FlaxModelMixin @flax.struct.dataclass From cce2aaaad65a54172752b47b5057845f9ad357c3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:40:25 +0530 Subject: [PATCH 05/21] gracefully deprecate controlnet deps. --- .../promptdiffusioncontrolnet.py | 2 +- src/diffusers/models/controlnet.py | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/controlnet.py diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py index 0e1eb9dcde49..46cabd863dfa 100644 --- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py +++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py @@ -16,7 +16,7 @@ import torch from diffusers.configuration_utils import register_to_config -from diffusers.models.controlnets.controlnet import ( +from diffusers.models.controlnet import ( ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput, diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py new file mode 100644 index 000000000000..7ef97cef0eb5 --- /dev/null +++ b/src/diffusers/models/controlnet.py @@ -0,0 +1,36 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ..utils import deprecate +from .controlnets.controlnet import ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput + + +class ControlNetOutput(ControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlents.controlnet import ControlNetOutput`, instead." + deprecate("ControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) + + +class ControlNetModel(ControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead." + deprecate("ControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) + + +class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead." + deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message) + super().__init__(*args, **kwargs) From 5a03eb5c8777f17980534ccc923c2d72f9b6296b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:42:28 +0530 Subject: [PATCH 06/21] fix --- src/diffusers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5e4cfb55e299..061b7ff9efcc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -422,7 +422,7 @@ else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] From f00d02fa68d7e751d9a1f8f51945497bfb0973b2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:46:23 +0530 Subject: [PATCH 07/21] fix doc path --- docs/source/en/api/models/controlnet_sd3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md index 59db64546fa2..78564d238eea 100644 --- a/docs/source/en/api/models/controlnet_sd3.md +++ b/docs/source/en/api/models/controlnet_sd3.md @@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di ## SD3ControlNetOutput -[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput +[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput From bfa2207e54f86c4659b20353bbef6aa8d438d488 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:48:08 +0530 Subject: [PATCH 08/21] fix-copies --- src/diffusers/utils/dummy_flax_objects.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index ea3559236f28..5fa8dbc81931 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -2,7 +2,7 @@ from 
..utils import DummyObject, requires_backends -class FlaxAutoencoderKL(metaclass=DummyObject): +class FlaxControlNetModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxControlNetModel(metaclass=DummyObject): +class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -32,7 +32,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxModelMixin(metaclass=DummyObject): +class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): @@ -47,7 +47,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): +class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): From a4515bbca78d17c83f8aac49ab0affbd1a4412c9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:52:21 +0530 Subject: [PATCH 09/21] fix path --- docs/source/en/api/models/controlnet.md | 4 ++-- .../promptdiffusion/promptdiffusioncontrolnet.py | 6 +++--- src/diffusers/models/controlnets/controlnet.py | 6 +++--- src/diffusers/models/controlnets/controlnet_xs.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md index c2fdf1c6f975..35f1a2e5a887 100644 --- a/docs/source/en/api/models/controlnet.md +++ b/docs/source/en/api/models/controlnet.md @@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro ## ControlNetOutput -[[autodoc]] models.controlnet.ControlNetOutput +[[autodoc]] models.controlnets.controlnet.ControlNetOutput ## FlaxControlNetModel @@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro ## FlaxControlNetOutput -[[autodoc]] models.controlnet_flax.FlaxControlNetOutput +[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py index 46cabd863dfa..6b1826a1c92d 100644 --- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py +++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py @@ -229,11 +229,11 @@ def forward( In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" # check channel order diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 9f390affc430..73a63c439791 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -710,11 +710,11 @@ def forward( In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 2d53d4c24e7a..a1df035d5e3f 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -1059,7 +1059,7 @@ def forward( added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. apply_control (`bool`, defaults to `True`): If `False`, the input is run only through the base model. From 940afd8722aff7e83c8b02dc155f626ecfb7b663 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 16:52:33 +0530 Subject: [PATCH 10/21] style --- src/diffusers/models/controlnets/controlnet.py | 7 ++++--- src/diffusers/models/controlnets/controlnet_xs.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 73a63c439791..bd00f6dd1906 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -710,12 +710,13 @@ def forward( In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain + tuple. Returns: [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is - returned where the first element is the sample tensor. + If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. 
""" # check channel order channel_order = self.config.controlnet_conditioning_channel_order diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index a1df035d5e3f..3f3cb11b6999 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -1059,7 +1059,8 @@ def forward( added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain + tuple. apply_control (`bool`, defaults to `True`): If `False`, the input is run only through the base model. From f31cfc2505a5ed061ce7a8202a3e64474433c6a4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 17:02:11 +0530 Subject: [PATCH 11/21] style --- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/controlnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1e46ef66988f..ca443f5ec5b1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -62,7 +62,7 @@ _import_structure["unets.uvit_2d"] = ["UVit2DModel"] if is_flax_available(): - _import_structure["controlents.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 7ef97cef0eb5..ec1abfc77b06 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -17,7 +17,7 @@ class ControlNetOutput(ControlNetOutput): def __init__(self, *args, **kwargs): - deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlents.controlnet import ControlNetOutput`, instead." + deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead." deprecate("ControlNetOutput", "0.34", deprecation_message) super().__init__(*args, **kwargs) From be0afc49c14d3face25f9098d0a2954a36ee2eb0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 17:41:19 +0530 Subject: [PATCH 12/21] conflicts --- src/diffusers/models/controlnet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index ec1abfc77b06..c2e6225e123c 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from ..utils import deprecate -from .controlnets.controlnet import ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput +from .controlnets.controlnet import ( + ControlNetConditioningEmbedding, + ControlNetModel, + ControlNetOutput, +) class ControlNetOutput(ControlNetOutput): From e7e7adf9c13d13cff517e5fd2dafa9bb07fef0ac Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 17:45:45 +0530 Subject: [PATCH 13/21] fix --- src/diffusers/models/controlnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index c2e6225e123c..174f2b9ada96 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from ..utils import deprecate -from .controlnets.controlnet import ( +from .controlnets.controlnet import ( # noqa + BaseOutput, ControlNetConditioningEmbedding, ControlNetModel, ControlNetOutput, + zero_module, ) From c6946da1383ceabf2168f08c124c17b302766b66 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 17:47:45 +0530 Subject: [PATCH 14/21] fix-copies --- src/diffusers/models/controlnet_sparsectrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index fa37e1f9e393..946be368145d 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -781,7 +781,7 @@ def forward( ) -# Copied from diffusers.models.controlnet.zero_module +# Copied from diffusers.models.controlnets.controlnet.zero_module def zero_module(module: nn.Module) -> nn.Module: for p in module.parameters(): nn.init.zeros_(p) From 80045823b969ef9c1fc9808818c7e29d0765d3a7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 17:59:35 +0530 Subject: [PATCH 15/21] sparsectrl. 
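
Move `SparseControlNetModel` and its companion classes into
`models/controlnets/controlnet_sparsectrl.py`, mirroring the earlier
ControlNet moves. `models/controlnet_sparsectrl.py` shrinks to a shim that
re-imports from the new location, and `models/__init__.py` plus
`controlnets/__init__.py` pick up the new path.

Rough backward-compatibility sketch (assumes a diffusers build containing
this patch; not part of the diff itself):

    from diffusers.models.controlnet_sparsectrl import SparseControlNetModel as LegacyCls
    from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel as MovedCls

    # The shim either re-exports the moved class directly or wraps it in a
    # thin deprecation subclass, so the old import path keeps resolving:
    assert issubclass(LegacyCls, MovedCls)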
--- src/diffusers/models/__init__.py | 3 +- src/diffusers/models/controlnet_sparsectrl.py | 784 +---------------- src/diffusers/models/controlnets/__init__.py | 5 + .../controlnets/controlnet_sparsectrl.py | 788 ++++++++++++++++++ 4 files changed, 816 insertions(+), 764 deletions(-) create mode 100644 src/diffusers/models/controlnets/controlnet_sparsectrl.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d170aea842f4..dca13491ec6e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,13 +35,13 @@ _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] - _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel", ] _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] + _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -101,6 +101,7 @@ HunyuanDiT2DMultiControlNetModel, SD3ControlNetModel, SD3MultiControlNetModel, + SparseControlNetModel, UNetControlNetXSModel, ) from .embeddings import ImageProjection diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index 946be368145d..1ccbd385b9a6 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -12,777 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import torch -from torch import nn -from torch.nn import functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalModelMixin -from ..utils import BaseOutput, logging -from .attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, +from ..utils import deprecate, logging +from .controlnets.controlnet_sparsectrl import ( # noqa + SparseControlNetConditioningEmbedding, + SparseControlNetModel, + SparseControlNetOutput, + zero_module, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn -from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class SparseControlNetOutput(BaseOutput): - """ - The output of [`SparseControlNetModel`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. 
- mid_down_block_re_sample (`torch.Tensor`): - The activation of the middle block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class SparseControlNetConditioningEmbedding(nn.Module): - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning: torch.Tensor) -> torch.Tensor: - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - return embedding - - -class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): - """ - A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion - Models](https://arxiv.org/abs/2311.16933). - - Args: - in_channels (`int`, defaults to 4): - The number of channels in the input sample. - conditioning_channels (`int`, defaults to 4): - The number of input channels in the controlnet conditional embedding module. If - `concat_condition_embedding` is True, the value provided here is incremented by 1. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, defaults to 0): - The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): - block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, defaults to 2): - The number of layers per block. - downsample_padding (`int`, defaults to 1): - The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, defaults to 1): - The scale factor to use for the mid block. - act_fn (`str`, defaults to "silu"): - The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the normalization. If None, normalization and activation layers is skipped - in post-processing. - norm_eps (`float`, defaults to 1e-5): - The epsilon to use for the normalization. - cross_attention_dim (`int`, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer layers to use in each layer in the middle block. - attention_head_dim (`int` or `Tuple[int]`, defaults to 8): - The dimension of the attention heads. - num_attention_heads (`int` or `Tuple[int]`, *optional*): - The number of heads to use for multi-head attention. - use_linear_projection (`bool`, defaults to `False`): - upcast_attention (`bool`, defaults to `False`): - resnet_time_scale_shift (`str`, defaults to `"default"`): - Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. - conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `conditioning_embedding` layer. - global_pool_conditions (`bool`, defaults to `False`): - TODO(Patrick) - unused parameter - controlnet_conditioning_channel_order (`str`, defaults to `rgb`): - motion_max_seq_length (`int`, defaults to `32`): - The maximum sequence length to use in the motion module. - motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): - The number of heads to use in each attention layer of the motion module. - concat_conditioning_mask (`bool`, defaults to `True`): - use_simplified_condition_embedding (`bool`, defaults to `True`): - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "DownBlockMotion", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 768, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, - temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - controlnet_conditioning_channel_order: str = "rgb", - motion_max_seq_length: int = 32, - motion_num_attention_heads: int = 8, - concat_conditioning_mask: bool = True, - use_simplified_condition_embedding: bool = True, - ): - super().__init__() - self.use_simplified_condition_embedding = use_simplified_condition_embedding - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) - - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - if concat_conditioning_mask: - conditioning_channels = conditioning_channels + 1 - - self.concat_conditioning_mask = concat_conditioning_mask - - # control net conditioning embedding - if use_simplified_condition_embedding: - self.controlnet_cond_embedding = zero_module( - nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - ) - else: - self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - ) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(motion_num_attention_heads, int): - motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - 
self.controlnet_down_blocks.append(controlnet_block) - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - if down_block_type == "CrossAttnDownBlockMotion": - down_block = CrossAttnDownBlockMotion( - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - dropout=0, - num_layers=layers_per_block, - transformer_layers_per_block=transformer_layers_per_block[i], - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - num_attention_heads=num_attention_heads[i], - cross_attention_dim=cross_attention_dim[i], - add_downsample=not is_final_block, - dual_cross_attention=False, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, - ) - elif down_block_type == "DownBlockMotion": - down_block = DownBlockMotion( - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - dropout=0, - num_layers=layers_per_block, - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - add_downsample=not is_final_block, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, - ) - else: - raise ValueError( - "Invalid `block_type` encountered. 
Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" - ) - - self.down_blocks.append(down_block) - - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channels = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - if transformer_layers_per_mid_block is None: - transformer_layers_per_mid_block = ( - transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 - ) - - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=mid_block_channels, - temb_channels=time_embed_dim, - dropout=0, - num_layers=1, - transformer_layers_per_block=transformer_layers_per_mid_block, - resnet_eps=norm_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - num_attention_heads=num_attention_heads[-1], - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - dual_cross_attention=False, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type="default", - ) - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - load_weights_from_unet: bool = True, - conditioning_channels: int = 3, - ) -> "SparseControlNetModel": - r""" - Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also - copied where applicable. - """ - transformer_layers_per_block = ( - unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 - ) - down_block_types = unet.config.down_block_types - - for i in range(len(down_block_types)): - if "CrossAttn" in down_block_types[i]: - down_block_types[i] = "CrossAttnDownBlockMotion" - elif "Down" in down_block_types[i]: - down_block_types[i] = "DownBlockMotion" - else: - raise ValueError("Invalid `block_type` encountered. 
Must be a cross-attention or down block") - - controlnet = cls( - in_channels=unet.config.in_channels, - conditioning_channels=conditioning_channels, - flip_sin_to_cos=unet.config.flip_sin_to_cos, - freq_shift=unet.config.freq_shift, - down_block_types=unet.config.down_block_types, - only_cross_attention=unet.config.only_cross_attention, - block_out_channels=unet.config.block_out_channels, - layers_per_block=unet.config.layers_per_block, - downsample_padding=unet.config.downsample_padding, - mid_block_scale_factor=unet.config.mid_block_scale_factor, - act_fn=unet.config.act_fn, - norm_num_groups=unet.config.norm_num_groups, - norm_eps=unet.config.norm_eps, - cross_attention_dim=unet.config.cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=unet.config.attention_head_dim, - num_attention_heads=unet.config.num_attention_heads, - use_linear_projection=unet.config.use_linear_projection, - upcast_attention=unet.config.upcast_attention, - resnet_time_scale_shift=unet.config.resnet_time_scale_shift, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - ) - - if load_weights_from_unet: - controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False) - controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False) - controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False) - controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) - controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) - - return controlnet - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. 
Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - conditioning_mask: Optional[torch.Tensor] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: - """ - The [`SparseControlNetModel`] forward method. - - Args: - sample (`torch.Tensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - guess_mode (`bool`, defaults to `False`): - In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if - you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is - returned where the first element is the sample tensor. 
- """ - sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape - sample = torch.zeros_like(sample) - - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(sample_num_frames, dim=0) - - # 2. pre-process - batch_size, channels, num_frames, height, width = sample.shape - - sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - sample = self.conv_in(sample) - - batch_frames, channels, height, width = sample.shape - sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) - - if self.concat_conditioning_mask: - controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) - - batch_size, channels, num_frames, height, width = controlnet_cond.shape - controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frames, channels, height, width - ) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - batch_frames, channels, height, width = controlnet_cond.shape - controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) - - sample = sample + controlnet_cond - - batch_size, num_frames, channels, height, width = sample.shape - sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) - - down_block_res_samples += res_samples - - # 4. 
mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample = self.mid_block(sample, emb) - - # 5. Control net blocks - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. scaling - if guess_mode and not self.config.global_pool_conditions: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample = mid_block_res_sample * scales[-1] # last one - else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] - mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) +class SparseControlNetOutput(SparseControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead." + deprecate("SparseControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - return SparseControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) +class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead." + deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message) + super().__init__(*args, **kwargs) -# Copied from diffusers.models.controlnets.controlnet.zero_module -def zero_module(module: nn.Module) -> nn.Module: - for p in module.parameters(): - nn.init.zeros_(p) - return module +class SparseControlNetModel(SparseControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead." 
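+        # the legacy import path keeps working for now: it only warns and then defers to the
+        # implementation that now lives in `diffusers.models.controlnets.controlnet_sparsectrl`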
+ deprecate("SparseControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 75f6fc541c5e..fb713009a456 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -9,6 +9,11 @@ HunyuanDiT2DMultiControlNetModel, ) from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel + from .controlnet_sparsectrl import ( + SparseControlNetConditioningEmbedding, + SparseControlNetModel, + SparseControlNetOutput, + ) from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel if is_flax_available(): diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py new file mode 100644 index 000000000000..946be368145d --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -0,0 +1,788 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalModelMixin +from ..utils import BaseOutput, logging +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unets.unet_2d_condition import UNet2DConditionModel +from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SparseControlNetOutput(BaseOutput): + """ + The output of [`SparseControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. 
+ """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class SparseControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning: torch.Tensor) -> torch.Tensor: + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + return embedding + + +class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion + Models](https://arxiv.org/abs/2311.16933). + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + conditioning_channels (`int`, defaults to 4): + The number of input channels in the controlnet conditional embedding module. If + `concat_condition_embedding` is True, the value provided here is incremented by 1. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each layer in the middle block. 
+ attention_head_dim (`int` or `Tuple[int]`, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of heads to use for multi-head attention. + use_linear_projection (`bool`, defaults to `False`): + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter + controlnet_conditioning_channel_order (`str`, defaults to `rgb`): + motion_max_seq_length (`int`, defaults to `32`): + The maximum sequence length to use in the motion module. + motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): + The number of heads to use in each attention layer of the motion module. + concat_conditioning_mask (`bool`, defaults to `True`): + use_simplified_condition_embedding (`bool`, defaults to `True`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 768, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + controlnet_conditioning_channel_order: str = "rgb", + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + concat_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = True, + ): + super().__init__() + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. 
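+        # e.g. with the defaults above (`num_attention_heads=None`, `attention_head_dim=8`),
+        # every block ends up using 8 attention heads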
+ num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + if concat_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + + self.concat_conditioning_mask = concat_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + ) + else: + self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + 
in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + add_downsample=not is_final_block, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + else: + raise ValueError( + "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channels = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=mid_block_channels, + temb_channels=time_embed_dim, + dropout=0, + num_layers=1, + transformer_layers_per_block=transformer_layers_per_mid_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[-1], + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type="default", + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ) -> "SparseControlNetModel": + r""" + Instantiate a 
[`SparseControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also + copied where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + down_block_types = unet.config.down_block_types + + for i in range(len(down_block_types)): + if "CrossAttn" in down_block_types[i]: + down_block_types[i] = "CrossAttnDownBlockMotion" + elif "Down" in down_block_types[i]: + down_block_types[i] = "DownBlockMotion" + else: + raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block") + + controlnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False) + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. 
+ + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
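+
+        Example (illustrative sketch; assumes `controlnet` is an instantiated `SparseControlNetModel`):
+
+        ```py
+        # compute attention in two steps per layer by halving each sliceable head dimension
+        controlnet.set_attention_slice("auto")
+        # run one slice at a time for the smallest memory footprint
+        controlnet.set_attention_slice("max")
+        ```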
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + conditioning_mask: Optional[torch.Tensor] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`SparseControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. 
If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape + sample = torch.zeros_like(sample) + + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(sample_num_frames, dim=0) + + # 2. 
pre-process + batch_size, channels, num_frames, height, width = sample.shape + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + sample = self.conv_in(sample) + + batch_frames, channels, height, width = sample.shape + sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) + + if self.concat_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + + batch_size, channels, num_frames, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width + ) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + batch_frames, channels, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) + + sample = sample + controlnet_cond + + batch_size, num_frames, channels, height, width = sample.shape + sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return SparseControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +# Copied from diffusers.models.controlnets.controlnet.zero_module +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module From 82e311589340a67e6edf12197d8f5916903fafb2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 18:06:58 +0530 Subject: [PATCH 16/21] updates --- src/diffusers/models/__init__.py | 7 +- src/diffusers/models/controlnet_flux.py | 529 +---------------- src/diffusers/models/controlnets/__init__.py | 1 + .../models/controlnets/controlnet_flux.py | 536 ++++++++++++++++++ 4 files changed, 557 insertions(+), 516 deletions(-) create mode 100644 src/diffusers/models/controlnets/controlnet_flux.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dca13491ec6e..947099a737cc 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -34,8 +34,8 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] - _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] + _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel", @@ -91,12 +91,11 @@ ConsistencyDecoderVAE, VQModel, ) - from .controlnet import ControlNetModel - from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel - from .controlnet_sparsectrl import SparseControlNetModel from .controlnets import ( ControlNetModel, ControlNetXSAdapter, + FluxControlNetModel, + FluxMultiControlNetModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel, SD3ControlNetModel, diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 961e30155a3d..9b256239d712 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -12,525 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin -from ..models.attention_processor import AttentionProcessor -from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..utils import deprecate, logging +from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class FluxControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - controlnet_single_block_samples: Tuple[torch.Tensor] - - -class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], - num_mode: int = None, - conditioning_embedding_channels: int = None, - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_single_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.controlnet_single_blocks = nn.ModuleList([]) - for _ in range(len(self.single_transformer_blocks)): - self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.union = num_mode is not None - if self.union: - self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - - if conditioning_embedding_channels is not None: - self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) - ) - 
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - else: - self.input_hint_block = None - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self): - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @classmethod - def from_transformer( - cls, - transformer, - num_layers: int = 4, - num_single_layers: int = 10, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - load_weights_from_transformer=True, - ): - config = transformer.config - config["num_layers"] = num_layers - config["num_single_layers"] = num_single_layers - config["attention_head_dim"] = attention_head_dim - config["num_attention_heads"] = num_attention_heads - - controlnet = cls(**config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - controlnet.single_transformer_blocks.load_state_dict( - transformer.single_transformer_blocks.state_dict(), strict=False - ) - - controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) - - return controlnet - - def forward( - self, - hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - controlnet_mode: torch.Tensor = None, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - controlnet_mode (`torch.Tensor`): - The mode tensor of shape `(batch_size, 1)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. 
- joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - if self.input_hint_block is not None: - controlnet_cond = self.input_hint_block(controlnet_cond) - batch_size, channels, height_pw, width_pw = controlnet_cond.shape - height = height_pw // self.config.patch_size - width = width_pw // self.config.patch_size - controlnet_cond = controlnet_cond.reshape( - batch_size, channels, height, self.config.patch_size, width, self.config.patch_size - ) - controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) - controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) - # add - hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if self.union: - # union mode - if controlnet_mode is None: - raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") - # union mode emb - controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) - encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) - txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - block_samples = () - for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - block_samples = block_samples + (hidden_states,) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - single_block_samples = () - for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) - - # controlnet block - controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): - block_sample = controlnet_block(block_sample) - controlnet_block_samples = controlnet_block_samples + (block_sample,) - - controlnet_single_block_samples = () - for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): - single_block_sample = controlnet_block(single_block_sample) - controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) - - # scaling - controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] - controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] - - controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples - controlnet_single_block_samples = ( - None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_samples, controlnet_single_block_samples) - - return FluxControlNetOutput( - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - ) - - -class FluxMultiControlNetModel(ModelMixin): - r""" - `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel - - This module is 
a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be - compatible with `FluxControlNetModel`. - - Args: - controlnets (`List[FluxControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `FluxControlNetModel` as a list. - """ - - def __init__(self, controlnets): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: List[torch.tensor], - controlnet_mode: List[torch.tensor], - conditioning_scale: List[float], - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[FluxControlNetOutput, Tuple]: - # ControlNet-Union with multiple conditions - # only load one ControlNet for saving memories - if len(self.nets) == 1 and self.nets[0].union: - controlnet = self.nets[0] - - for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): - block_samples, single_block_samples = controlnet( - hidden_states=hidden_states, - controlnet_cond=image, - controlnet_mode=mode[:, None], - conditioning_scale=scale, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_projections, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - img_ids=img_ids, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) - - # merge samples - if i == 0: - control_block_samples = block_samples - control_single_block_samples = single_block_samples - else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] +class FluxControlNetOutput(FluxControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead." + deprecate("FluxControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] - # Regular Multi-ControlNets - # load all ControlNets into memories - else: - for i, (image, mode, scale, controlnet) in enumerate( - zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) - ): - block_samples, single_block_samples = controlnet( - hidden_states=hidden_states, - controlnet_cond=image, - controlnet_mode=mode[:, None], - conditioning_scale=scale, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_projections, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - img_ids=img_ids, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) +class FluxControlNetModel(FluxControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. 
Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead." + deprecate("FluxControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) - # merge samples - if i == 0: - control_block_samples = block_samples - control_single_block_samples = single_block_samples - else: - if block_samples is not None and control_block_samples is not None: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] - if single_block_samples is not None and control_single_block_samples is not None: - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] - return control_block_samples, control_single_block_samples +class FluxMultiControlNetModel(FluxMultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead." + deprecate("FluxMultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index fb713009a456..d9690182983d 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -3,6 +3,7 @@ if is_torch_available(): from .controlnet import ControlNetModel, ControlNetOutput + from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel from .controlnet_hunyuan import ( HunyuanControlNetOutput, HunyuanDiT2DControlNetModel, diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py new file mode 100644 index 000000000000..961e30155a3d --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -0,0 +1,536 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin +from ..models.attention_processor import AttentionProcessor +from ..models.modeling_utils import ModelMixin +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from .modeling_outputs import Transformer2DModelOutput +from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FluxControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + controlnet_single_block_samples: Tuple[torch.Tensor] + + +class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + num_mode: int = None, + conditioning_embedding_channels: int = None, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.union = num_mode is not None + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) + + if conditioning_embedding_channels is not None: + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.input_hint_block = None + self.controlnet_x_embedder = 
zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + num_single_layers: int = 10, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = transformer.config + config["num_layers"] = num_layers + config["num_single_layers"] = num_single_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + controlnet.single_transformer_blocks.load_state_dict( + transformer.single_transformer_blocks.state_dict(), strict=False + ) + + controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) + + return controlnet + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + if self.input_hint_block is not None: + controlnet_cond = self.input_hint_block(controlnet_cond) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) + # add + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + block_samples = block_samples + (hidden_states,) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + single_block_samples = () + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + + # controlnet block + controlnet_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + + controlnet_single_block_samples = () + for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + single_block_sample = controlnet_block(single_block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) + + # scaling + controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] + controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] + + controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples + controlnet_single_block_samples = ( + None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_samples, controlnet_single_block_samples) + + return FluxControlNetOutput( + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) + + +class FluxMultiControlNetModel(ModelMixin): + r""" + `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel + + This module is 
a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be + compatible with `FluxControlNetModel`. + + Args: + controlnets (`List[FluxControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `FluxControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + controlnet_mode: List[torch.tensor], + conditioning_scale: List[float], + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[FluxControlNetOutput, Tuple]: + # ControlNet-Union with multiple conditions + # only load one ControlNet for saving memories + if len(self.nets) == 1 and self.nets[0].union: + controlnet = self.nets[0] + + for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + # Regular Multi-ControlNets + # load all ControlNets into memories + else: + for i, (image, mode, scale, controlnet) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) + ): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + return control_block_samples, control_single_block_samples From 60bfa57326b2c64ce9771d60ef54c44201d1be20 Mon Sep 17 00:00:00 2001 
From: sayakpaul Date: Wed, 23 Oct 2024 18:09:21 +0530 Subject: [PATCH 17/21] fix --- .../models/controlnets/controlnet_flux.py | 18 +++++++++--------- .../controlnets/controlnet_sparsectrl.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 961e30155a3d..e6a3eceed9b4 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -18,15 +18,15 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin -from ..models.attention_processor import AttentionProcessor -from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.attention_processor import AttentionProcessor +from ...models.modeling_utils import ModelMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 946be368145d..fd599c10b2d7 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -19,21 +19,21 @@ from torch import nn from torch.nn import functional as F -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalModelMixin -from ..utils import BaseOutput, logging -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import BaseOutput, logging +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn -from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import UNetMidBlock2DCrossAttn +from ..unets.unet_2d_condition import UNet2DConditionModel +from ..unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name 
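A minimal sketch of what the deprecation shims introduced above mean for downstream imports (not part of the patch series; it assumes a diffusers build that already contains these changes, and the constructor values are deliberately tiny and purely illustrative): the modules under diffusers.models.controlnets become the canonical location, while the old module paths keep resolving through thin subclasses whose __init__ calls deprecate(..., "0.34", ...) before delegating to the relocated implementation.

# Illustrative sketch only; assumes a diffusers install that includes this patch series.
from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel  # canonical path after the move

# The legacy module still imports, backed by the shim subclass; constructing through it
# additionally emits the deprecation notice scheduled for removal in 0.34.
from diffusers.models.controlnet_flux import FluxControlNetModel as LegacyFluxControlNetModel

assert issubclass(LegacyFluxControlNetModel, FluxControlNetModel)

# A deliberately tiny configuration (illustrative values only) built from the canonical class:
model = FluxControlNetModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=16,
    pooled_projection_dim=8,
)
print(sum(p.numel() for p in model.parameters()))  # confirms the model instantiates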
From 019a5ac5813dc398153a066707e2f6e31838f14e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 18:20:44 +0530 Subject: [PATCH 18/21] updates --- src/diffusers/models/controlnet_sd3.py | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 src/diffusers/models/controlnet_sd3.py diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py new file mode 100644 index 000000000000..5e70559e9ac4 --- /dev/null +++ b/src/diffusers/models/controlnet_sd3.py @@ -0,0 +1,41 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import deprecate, logging +from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SD3ControlNetOutput(SD3ControlNetOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead." + deprecate("SD3ControlNetOutput", "0.34", deprecation_message) + super().__init__(*args, **kwargs) + + +class SD3ControlNetModel(SD3ControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead." + deprecate("SD3ControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) + + +class SD3MultiControlNetModel(SD3MultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead." 
+ deprecate("SD3MultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) From 7623d6c904a52b8fe23b34839af8697fad7dda23 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 19:00:22 +0530 Subject: [PATCH 19/21] updates --- .../pipelines/animatediff/pipeline_animatediff_sparsectrl.py | 2 +- .../pipeline_stable_diffusion_3_controlnet_inpainting.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 8b037cdc34fb..6dde7d6686ee 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel -from ...models.controlnet_sparsectrl import SparseControlNetModel +from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index f362c8f3d0c1..437bb9f2f182 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 5136c4200147..a6bf57def7b4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -27,7 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 46784f2d46d1..35822869ee85 100644 --- 
a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -14,7 +14,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( From 39d747bb2d8ed039d3bd6587c4a355e406e70df8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 19:15:26 +0530 Subject: [PATCH 20/21] updates --- .../pipelines/flux/pipeline_flux_controlnet_image_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 7b40ddfca79a..fa2f9f638b8d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -13,7 +13,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( From 8edc2c0381c8ac0fe1a36531200c14c81e8bd33f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 23 Oct 2024 19:25:58 +0530 Subject: [PATCH 21/21] fix --- src/diffusers/models/__init__.py | 2 + src/diffusers/models/controlnets/__init__.py | 1 + .../models/controlnets/multicontrolnet.py | 183 +++++++++++++++++ .../pipelines/controlnet/multicontrolnet.py | 185 +----------------- 4 files changed, 193 insertions(+), 178 deletions(-) create mode 100644 src/diffusers/models/controlnets/multicontrolnet.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 947099a737cc..469d7429b9b5 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -43,6 +43,7 @@ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] @@ -98,6 +99,7 @@ FluxMultiControlNetModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel, + MultiControlNetModel, SD3ControlNetModel, SD3MultiControlNetModel, SparseControlNetModel, diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index d9690182983d..3e4b3561e839 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ 
b/src/diffusers/models/controlnets/__init__.py @@ -16,6 +16,7 @@ SparseControlNetOutput, ) from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .multicontrolnet import MultiControlNetModel if is_flax_available(): from .controlnet_flax import FlaxControlNetModel diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py new file mode 100644 index 000000000000..46c3d1681cc1 --- /dev/null +++ b/src/diffusers/models/controlnets/multicontrolnet.py @@ -0,0 +1,183 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. 
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training, e.g. on
+                TPUs, when you need to call this function on all processes. In this case, set `is_main_process=True` only
+                on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training, e.g. on TPUs, when
+                you need to replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format pytorch_model..bin.
+        """
+        for idx, controlnet in enumerate(self.nets):
+            suffix = "" if idx == 0 else f"_{idx}"
+            controlnet.save_pretrained(
+                save_directory + suffix,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+                `./my_model_directory/controlnet`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be specified for each
+                parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available
+                for each GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified, load weights from the `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+                ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        controlnets = []
+
+        # load controlnet and append to list until no controlnet directory exists anymore
+        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
+        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+            controlnets.append(controlnet)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+        if len(controlnets) == 0:
+            raise ValueError(
+                f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(controlnets)
diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py
index 46c3d1681cc1..33790c10e064 100644
--- a/src/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,183 +1,12 @@
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
-from ...utils import logging
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
+from ...utils import deprecate, logging
 
 
 logger = logging.get_logger(__name__)
 
 
-class MultiControlNetModel(ModelMixin):
-    r"""
-    Multiple `ControlNetModel` wrapper class for Multi-ControlNet
-
-    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
-    compatible with `ControlNetModel`.
-
-    Args:
-        controlnets (`List[ControlNetModel]`):
-            Provides additional conditioning to the unet during the denoising process. You must set multiple
-            `ControlNetModel` as a list.
- """ - - def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - down_samples, mid_sample = controlnet( - sample=sample, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=image, - conditioning_scale=scale, - class_labels=class_labels, - timestep_cond=timestep_cond, - attention_mask=attention_mask, - added_cond_kwargs=added_cond_kwargs, - cross_attention_kwargs=cross_attention_kwargs, - guess_mode=guess_mode, - return_dict=return_dict, - ) - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) - ] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - save_function: Callable = None, - safe_serialization: bool = True, - variant: Optional[str] = None, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. - """ - for idx, controlnet in enumerate(self.nets): - suffix = "" if idx == 0 else f"_{idx}" - controlnet.save_pretrained( - save_directory + suffix, - is_main_process=is_main_process, - save_function=save_function, - safe_serialization=safe_serialization, - variant=variant, - ) - - @classmethod - def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). 
To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_path (`os.PathLike`): - A path to a *directory* containing model weights saved using - [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., - `./my_model_directory/controlnet`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each - GPU and the available CPU RAM if unset. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from - `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. - """ - idx = 0 - controlnets = [] - - # load controlnet and append to list until no controlnet directory exists anymore - # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` - # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... 
- model_path_to_load = pretrained_model_path - while os.path.isdir(model_path_to_load): - controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) - controlnets.append(controlnet) - - idx += 1 - model_path_to_load = pretrained_model_path + f"_{idx}" - - logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") - - if len(controlnets) == 0: - raise ValueError( - f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." - ) - - return cls(controlnets) +class MultiControlNetModel(MultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead." + deprecate("MultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs)
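The final patch above relocates `MultiControlNetModel` to `diffusers.models.controlnets.multicontrolnet` and leaves a deprecation shim at the old pipelines path. The snippet below is a minimal usage sketch, not part of the patch: it assumes the relocated module as introduced above, and the checkpoint ids and local directory are placeholders (any two compatible ControlNet checkpoints would do).

# Minimal sketch of the relocated MultiControlNetModel (placeholder checkpoint ids and paths).
import torch

from diffusers import ControlNetModel
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel  # new location

# Wrap two ControlNets; forward() fans the conditioning images out over self.nets
# and sums the down/mid residuals, as in the module added by this patch.
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
multi = MultiControlNetModel([canny, depth])

# save_pretrained() writes the first net to `<dir>` and the rest to `<dir>_1`, `<dir>_2`, ...
multi.save_pretrained("./my-multi-controlnet/controlnet")

# from_pretrained() walks the same suffix convention until no further directory is found.
reloaded = MultiControlNetModel.from_pretrained("./my-multi-controlnet/controlnet", torch_dtype=torch.float16)

# The old import path still resolves, but instantiating through it emits the deprecation
# warning added by the shim, then behaves like the relocated class.
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel as LegacyMultiControlNetModel

legacy = LegacyMultiControlNetModel([canny, depth])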