@@ -18,7 +18,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
@@ -32,7 +31,7 @@
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
@@ -45,20 +44,7 @@
 
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
-    """
-
-    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)
 
@@ -68,9 +54,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            deprecation_message = (
+                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                "should be set explicitly using the `set_attn_processor` method."
+            )
+            deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
             processor = FluxAttnProcessor2_0()
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -113,39 +105,14 @@ def forward(
 
 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Args:
-        dim (`int`):
-            The embedding dimension of the block.
-        num_attention_heads (`int`):
-            The number of attention heads to use.
-        attention_head_dim (`int`):
-            The number of dimensions to use for each attention head.
-        qk_norm (`str`, defaults to `"rms_norm"`):
-            The normalization to use for the query and key tensors.
-        eps (`float`, defaults to `1e-6`):
-            The epsilon value to use for the normalization.
-    """
-
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
     ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
         self.norm1_context = AdaLayerNormZero(dim)
 
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -155,7 +122,7 @@ def __init__(
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor2_0(),
             qk_norm=qk_norm,
             eps=eps,
         )
@@ -166,10 +133,6 @@ def __init__(
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-        # let chunk size default to None
-        self._chunk_size = None
-        self._chunk_dim = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
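Note on the deprecation introduced above: the message tells users to pick an attention processor explicitly via `set_attn_processor` rather than relying on the NPU auto-default. A minimal sketch of that explicit opt-in follows; it assumes the public `FluxTransformer2DModel.set_attn_processor` API and uses the `black-forest-labs/FLUX.1-dev` checkpoint purely as an illustration, neither is part of this diff.

# Sketch: explicitly choose the attention processor instead of relying on the
# soon-to-be-removed NPU auto-default. Checkpoint id and dtype are illustrative.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxAttnProcessor2_0_NPU
from diffusers.utils.import_utils import is_torch_npu_available

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Pick the processor explicitly; after the deprecation window this choice is no
# longer made automatically for NPU devices.
processor = FluxAttnProcessor2_0_NPU() if is_torch_npu_available() else FluxAttnProcessor2_0()
transformer.set_attn_processor(processor)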