
Commit 13f20c7

[refactor] SD3 docs & remove additional code (#10882)
* update * update * update
1 parent 8759969 commit 13f20c7

File tree: 3 files changed, +107 additions, -80 deletions

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -1410,7 +1410,7 @@ class JointAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
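For context, `JointAttnProcessor2_0` depends on `torch.nn.functional.scaled_dot_product_attention`, which was added in PyTorch 2.0; the change only makes the error message name the right class. A standalone sketch of the same version guard (illustrative, not part of this commit):

import torch.nn.functional as F

# Fused scaled-dot-product attention only exists in PyTorch >= 2.0,
# so the processor refuses to be constructed on older versions.
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")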

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 54 additions & 12 deletions
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
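To make the documented parameters concrete, here is a hedged construction sketch; the sizes below are arbitrary toy values chosen for illustration, not the configuration of any released SD3 ControlNet checkpoint:

from diffusers import SD3ControlNetModel

# Toy configuration; real checkpoints use the documented defaults
# (sample_size=128, num_layers=18, attention_head_dim=64, ...).
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # assumed to match num_attention_heads * attention_head_dim
    pooled_projection_dim=32,
    out_channels=16,
    pos_embed_max_size=32,
)

# All constructor arguments are recorded by @register_to_config, which is how the
# pipelines read flags such as force_zeros_for_pooled_projection at runtime.
print(controlnet.config.force_zeros_for_pooled_projection)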
@@ -93,7 +135,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ def __init__(
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -297,28 +339,28 @@ def from_transformer(
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
@@ -437,11 +479,11 @@ def __init__(self, controlnets):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 52 additions & 67 deletions
@@ -15,7 +15,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@
 
 @maybe_allow_in_graph
 class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -59,45 +47,31 @@ def __init__(
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
         self.attn = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
             eps=1e-6,
         )
 
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        # 1. Attention
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        # 2. Feed Forward
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
-
         hidden_states = hidden_states + ff_output
 
         return hidden_states
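The modulation line in the feed-forward branch only changes notation: for a `(batch, dim)` tensor, `t[:, None]` and `t.unsqueeze(1)` produce the same `(batch, 1, dim)` view, so broadcasting against the `(batch, seq_len, dim)` activations is identical. A quick self-contained check:

import torch

batch, seq_len, dim = 2, 4, 8
norm_hidden_states = torch.randn(batch, seq_len, dim)
scale_mlp = torch.randn(batch, dim)
shift_mlp = torch.randn(batch, dim)

# Old indexing style vs. new unsqueeze style of the AdaLayerNorm modulation.
old = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
new = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

assert torch.equal(old, new)  # identical values; the change is purely stylistic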
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
 ):
     """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
 
     Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
@@ -149,36 +137,33 @@ def __init__(
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
 
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
                 JointTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
                 )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
             ]
         )
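The constructor now reads its plain `__init__` arguments instead of round-tripping through `self.config.*`. From the outside this is equivalent, since `@register_to_config` still records every argument on the model's config; a small sketch with arbitrary toy sizes:

from diffusers import SD3Transformer2DModel

# Toy sizes, purely illustrative.
transformer = SD3Transformer2DModel(
    sample_size=32, patch_size=2, in_channels=16, num_layers=2,
    attention_head_dim=8, num_attention_heads=4, joint_attention_dim=32,
    caption_projection_dim=32, pooled_projection_dim=32, out_channels=16,
    pos_embed_max_size=32,
)

# The arguments remain available via the registered config ...
assert transformer.config.num_layers == 2
# ... and inner_dim is still derived the same way as before the refactor.
assert transformer.inner_dim == transformer.config.num_attention_heads * transformer.config.attention_head_dim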

@@ -331,24 +316,24 @@ def unfuse_qkv_projections(self):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                 Embeddings projected from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
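With the annotations relaxed to `torch.Tensor`, here is a minimal forward smoke test based on the shapes documented above; the tiny configuration and tensor sizes are arbitrary assumptions for illustration, not a real SD3 checkpoint setup:

import torch
from diffusers import SD3Transformer2DModel

# Same toy configuration as the sketch above; not a released SD3 model.
transformer = SD3Transformer2DModel(
    sample_size=32, patch_size=2, in_channels=16, num_layers=2,
    attention_head_dim=8, num_attention_heads=4, joint_attention_dim=32,
    caption_projection_dim=32, pooled_projection_dim=32, out_channels=16,
    pos_embed_max_size=32,
)

batch = 1
hidden_states = torch.randn(batch, 16, 32, 32)      # (batch, channel, height, width) latents
encoder_hidden_states = torch.randn(batch, 77, 32)  # (batch, sequence_len, joint_attention_dim)
pooled_projections = torch.randn(batch, 32)         # (batch, pooled_projection_dim)
timestep = torch.tensor([1], dtype=torch.long)

with torch.no_grad():
    out = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
    )
print(out.sample.shape)  # expected: torch.Size([1, 16, 32, 32])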
