import torch
import torch.nn as nn
- import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin

@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
-     r"""
-     A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-     Reference: https://arxiv.org/abs/2403.03206
-
-     Parameters:
-         dim (`int`): The number of channels in the input and output.
-         num_attention_heads (`int`): The number of heads to use for multi-head attention.
-         attention_head_dim (`int`): The number of channels in each head.
-     """
-
    def __init__(
        self,
        dim: int,
@@ -59,45 +47,31 @@ def __init__(
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
-
-         if hasattr(F, "scaled_dot_product_attention"):
-             processor = JointAttnProcessor2_0()
-         else:
-             raise ValueError(
-                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-             )
-
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
-             processor=processor,
+             processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+         # 1. Attention
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-         # Attention.
-         attn_output = self.attn(
-             hidden_states=norm_hidden_states,
-             encoder_hidden_states=None,
-         )
-
-         # Process attention outputs for the `hidden_states`.
+         attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

+         # 2. Feed Forward
        norm_hidden_states = self.norm2(hidden_states)
-         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
-
        hidden_states = hidden_states + ff_output

        return hidden_states
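For reference, a minimal sketch (not part of this commit) of driving the refactored block directly. The dimensions are illustrative SD3-Medium-like values (dim = 24 * 64), and the direct module import path is an assumption about the internal layout:

import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

# Build a single block; dim must equal num_attention_heads * attention_head_dim.
block = SD3SingleTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)
hidden_states = torch.randn(2, 4096, 1536)  # (batch, tokens, dim)
temb = torch.randn(2, 1536)                 # conditioning embedding consumed by AdaLayerNormZero
out = block(hidden_states, temb)            # attention + gated feed-forward; same shape as the input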
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
-     The Transformer model introduced in Stable Diffusion 3.
-
-     Reference: https://arxiv.org/abs/2403.03206
+     The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
-         sample_size (`int`): The width of the latent images. This is fixed during training since
-             it is used to learn a number of position embeddings.
-         patch_size (`int`): Patch size to turn the input data into small patches.
-         in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-         num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-         attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-         num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-         cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-         pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-         out_channels (`int`, defaults to 16): Number of output channels.
-
+         sample_size (`int`, defaults to `128`):
+             The width/height of the latents. This is fixed during training since it is used to learn a number of
+             position embeddings.
+         patch_size (`int`, defaults to `2`):
+             Patch size to turn the input data into small patches.
+         in_channels (`int`, defaults to `16`):
+             The number of latent channels in the input.
+         num_layers (`int`, defaults to `18`):
+             The number of layers of transformer blocks to use.
+         attention_head_dim (`int`, defaults to `64`):
+             The number of channels in each head.
+         num_attention_heads (`int`, defaults to `18`):
+             The number of heads to use for multi-head attention.
+         joint_attention_dim (`int`, defaults to `4096`):
+             The embedding dimension to use for joint text-image attention.
+         caption_projection_dim (`int`, defaults to `1152`):
+             The embedding dimension of caption embeddings.
+         pooled_projection_dim (`int`, defaults to `2048`):
+             The embedding dimension of pooled text projections.
+         out_channels (`int`, defaults to `16`):
+             The number of latent channels in the output.
+         pos_embed_max_size (`int`, defaults to `96`):
+             The maximum latent height/width of positional embeddings.
+         dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+             The indices of the transformer blocks that use dual-stream attention.
+         qk_norm (`str`, *optional*, defaults to `None`):
+             The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
    """

    _supports_gradient_checkpointing = True
+     _no_split_modules = ["JointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
@@ -149,36 +137,33 @@ def __init__(
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
-         default_out_channels = in_channels
-         self.out_channels = out_channels if out_channels is not None else default_out_channels
-         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+         self.out_channels = out_channels if out_channels is not None else in_channels
+         self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
-             height=self.config.sample_size,
-             width=self.config.sample_size,
-             patch_size=self.config.patch_size,
-             in_channels=self.config.in_channels,
+             height=sample_size,
+             width=sample_size,
+             patch_size=patch_size,
+             in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-             embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
-         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+         self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

-         # `attention_head_dim` is doubled to account for the mixing.
-         # It needs to crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
-                     num_attention_heads=self.config.num_attention_heads,
-                     attention_head_dim=self.config.attention_head_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
-                 for i in range(self.config.num_layers)
+                 for i in range(num_layers)
            ]
        )
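For context, a minimal sketch (not part of this commit) showing that, after this refactor, the constructor reads its values straight from the arguments rather than back from `self.config`, while `@register_to_config` still populates the config. The keyword values simply mirror the defaults documented above, so this builds the full default-size module:

from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=128,
    patch_size=2,
    in_channels=16,
    num_layers=18,
    attention_head_dim=64,
    num_attention_heads=18,
    joint_attention_dim=4096,
    caption_projection_dim=1152,
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=96,
)
assert model.config.num_layers == 18            # registered by @register_to_config
assert model.inner_dim == 18 * 64               # num_attention_heads * attention_head_dim
assert len(model.transformer_blocks) == 18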
@@ -331,24 +316,24 @@ def unfuse_qkv_projections(self):

    def forward(
        self,
-         hidden_states: torch.FloatTensor,
-         encoder_hidden_states: torch.FloatTensor = None,
-         pooled_projections: torch.FloatTensor = None,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
-     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
-             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+             hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
-             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+             encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+             pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
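As a usage illustration (not part of this commit), a rough sketch of calling the forward method with dummy inputs shaped per the docstring above; the default config values (16 latent channels, 4096-dim `encoder_hidden_states`, 2048-dim pooled projections) and the `.sample` output field are assumptions taken from the surrounding code:

import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel()                  # documented defaults
latents = torch.randn(1, 16, 128, 128)           # (batch size, channel, height, width)
prompt_embeds = torch.randn(1, 333, 4096)        # (batch size, sequence_len, embed_dims); length is arbitrary
pooled_prompt_embeds = torch.randn(1, 2048)      # (batch_size, projection_dim)
timestep = torch.tensor([500], dtype=torch.long)

with torch.no_grad():
    out = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled_prompt_embeds,
        timestep=timestep,
    )
print(out.sample.shape)  # torch.Size([1, 16, 128, 128])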