
Commit 0404703

[refactor] Remove additional Flux code (#10881)

Authored by a-r-r-o-w and DN6

* update
* apply review suggestions

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 13f20c7 commit 0404703

File tree

1 file changed (+9 / -46 lines)


src/diffusers/models/transformers/transformer_flux.py

Lines changed: 9 additions & 46 deletions
@@ -18,7 +18,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
@@ -32,7 +31,7 @@
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
@@ -45,20 +44,7 @@

 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
-    """
-
-    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)

@@ -68,9 +54,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

         if is_torch_npu_available():
+            deprecation_message = (
+                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                "should be set explicitly using the `set_attn_processor` method."
+            )
+            deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
             processor = FluxAttnProcessor2_0()
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -113,39 +105,14 @@ def forward(

 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Args:
-        dim (`int`):
-            The embedding dimension of the block.
-        num_attention_heads (`int`):
-            The number of attention heads to use.
-        attention_head_dim (`int`):
-            The number of dimensions to use for each attention head.
-        qk_norm (`str`, defaults to `"rms_norm"`):
-            The normalization to use for the query and key tensors.
-        eps (`float`, defaults to `1e-6`):
-            The epsilon value to use for the normalization.
-    """
-
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
     ):
         super().__init__()

         self.norm1 = AdaLayerNormZero(dim)
-
         self.norm1_context = AdaLayerNormZero(dim)

-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -155,7 +122,7 @@ def __init__(
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor2_0(),
             qk_norm=qk_norm,
             eps=eps,
         )
@@ -166,10 +133,6 @@ def __init__(
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

-        # let chunk size default to None
-        self._chunk_size = None
-        self._chunk_dim = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
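The deprecation added in this commit tells NPU users to opt in to `FluxAttnProcessor2_0_NPU` themselves rather than relying on the constructor's automatic fallback. Below is a minimal sketch of that migration, not part of the commit itself; it assumes a diffusers release where `FluxAttnProcessor2_0_NPU` is importable from `diffusers.models.attention_processor`, and the checkpoint id is only an illustrative example.

```python
# Sketch: set the NPU attention processor explicitly via `set_attn_processor`,
# as the deprecation message suggests, instead of depending on the implicit
# default inside the Flux transformer blocks. Checkpoint id is an example.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU
from diffusers.utils.import_utils import is_torch_npu_available

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

if is_torch_npu_available():
    # Explicit opt-in replaces the deprecated implicit NPU fallback; passing a
    # single processor applies it to every attention module in the model.
    transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())
```

On non-NPU devices nothing changes: the blocks still construct `FluxAttnProcessor2_0()` by default, as the `else` branch in the diff shows.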
