Commit 1123ee7

Add cross attention type for Sana-Sprint.
1 parent 5d9a5da

File tree

1 file changed: +96 -1 lines changed


src/diffusers/models/transformers/sana_transformer.py

Lines changed: 96 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 from typing import Any, Dict, Optional, Tuple, Union
 
+import math
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -184,6 +185,91 @@ def __call__(
 
         return hidden_states
 
+
+class SanaAttnProcessor3_0:
+    r"""
+    Processor implementing scaled dot-product attention with a plain (non-fused) PyTorch implementation.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    @staticmethod
+    def scaled_dot_product_attention(
+        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    ) -> torch.Tensor:
+        B, H, L, S = *query.size()[:-1], key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = self.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
 
 class SanaTransformerBlock(nn.Module):
     r"""
@@ -205,6 +291,7 @@ def __init__(
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
         qk_norm: Optional[str] = None,
+        cross_attention_type: str = "flash",
     ) -> None:
         super().__init__()
 
@@ -223,6 +310,12 @@ def __init__(
         )
 
         # 2. Cross Attention
+        if cross_attention_type == "flash":
+            cross_attention_processor = SanaAttnProcessor2_0()
+        elif cross_attention_type == "vanilla":
+            cross_attention_processor = SanaAttnProcessor3_0()
+        else:
+            raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             self.attn2 = Attention(
@@ -235,7 +328,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=SanaAttnProcessor2_0(),
+                processor=cross_attention_processor,
             )
 
         # 3. Feed-forward
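
Since the processor is now chosen once in the constructor, an already-built block can also be switched afterwards through the processor API on diffusers' Attention module. A hedged sketch, assuming block is an existing SanaTransformerBlock with cross attention enabled:

# Hypothetical after-the-fact swap (not shown in the commit): Attention.set_processor()
# replaces the processor on an existing module, here the block's cross attention (attn2).
block.attn2.set_processor(SanaAttnProcessor3_0())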
@@ -360,6 +453,7 @@ def __init__(
         guidance_embeds_scale: float = 0.1,
         qk_norm: Optional[str] = None,
         timestep_scale: float = 1.0,
+        cross_attention_type: str = "flash",
     ) -> None:
         super().__init__()
 
@@ -402,6 +496,7 @@ def __init__(
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
                     qk_norm=qk_norm,
+                    cross_attention_type=cross_attention_type,
                 )
                 for _ in range(num_layers)
             ]
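
Taken together, the new cross_attention_type argument threads from the model constructor down to every SanaTransformerBlock: "flash" (the default) keeps the previous SanaAttnProcessor2_0 behaviour and "vanilla" selects the new SanaAttnProcessor3_0. A minimal usage sketch follows, assuming the enclosing model class in sana_transformer.py is diffusers' SanaTransformer2DModel and that its remaining constructor defaults are acceptable; the tiny num_layers value is purely illustrative.

# Usage sketch (not from the commit): build the transformer with the written-out
# cross attention path instead of the fused one.
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel(
    num_layers=2,                    # illustrative, not Sana-Sprint's real depth
    cross_attention_type="vanilla",  # "flash" (default) keeps SanaAttnProcessor2_0
)

Because the default stays "flash", configurations created before this commit continue to instantiate exactly as they did.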
