from typing import Any, Dict, Optional, Tuple, Union

+ import math
import torch
import torch.nn.functional as F
from torch import nn
@@ -184,6 +185,91 @@ def __call__(
        return hidden_states

+
+ class SanaAttnProcessor3_0:
+     r"""
+     Processor implementing scaled dot-product attention with plain PyTorch operations (matmul, softmax,
+     dropout) instead of `F.scaled_dot_product_attention`.
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("SanaAttnProcessor3_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     @staticmethod
+     def scaled_dot_product_attention(
+         query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+     ) -> torch.Tensor:
+         # NOTE: `is_causal` is accepted for signature parity with `F.scaled_dot_product_attention` but is not used here.
+         # Batch size and head count come from `query`; the source length comes from `key`.
+         B, H, L, S = *query.size()[:-1], key.size(-2)
+         scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+         attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+         if attn_mask is not None:
+             if attn_mask.dtype == torch.bool:
+                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+             else:
+                 attn_bias += attn_mask
+         attn_weight = query @ key.transpose(-2, -1) * scale_factor
+         attn_weight += attn_bias
+         attn_weight = torch.softmax(attn_weight, dim=-1)
+         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+         return attn_weight @ value
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = self.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+

class SanaTransformerBlock(nn.Module):
    r"""
@@ -205,6 +291,7 @@ def __init__(
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
        qk_norm: Optional[str] = None,
+         cross_attention_type: str = "flash",
    ) -> None:
        super().__init__()
@@ -223,6 +310,12 @@ def __init__(
        )

        # 2. Cross Attention
+         if cross_attention_type == "flash":
+             cross_attention_processor = SanaAttnProcessor2_0()
+         elif cross_attention_type == "vanilla":
+             cross_attention_processor = SanaAttnProcessor3_0()
+         else:
+             raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
        if cross_attention_dim is not None:
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            self.attn2 = Attention(
@@ -235,7 +328,7 @@ def __init__(
                dropout=dropout,
                bias=True,
                out_bias=attention_out_bias,
-                 processor=SanaAttnProcessor2_0(),
+                 processor=cross_attention_processor,
            )

        # 3. Feed-forward
@@ -360,6 +453,7 @@ def __init__(
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = None,
        timestep_scale: float = 1.0,
+         cross_attention_type: str = "flash",
    ) -> None:
        super().__init__()
@@ -402,6 +496,7 @@ def __init__(
402
496
norm_eps = norm_eps ,
403
497
mlp_ratio = mlp_ratio ,
404
498
qk_norm = qk_norm ,
499
+ cross_attention_type = cross_attention_type ,
405
500
)
406
501
for _ in range (num_layers )
407
502
]
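
With this change, cross-attention in each SanaTransformerBlock becomes configurable: the default cross_attention_type="flash" keeps the existing SanaAttnProcessor2_0 (which calls F.scaled_dot_product_attention), "vanilla" swaps in the manual SanaAttnProcessor3_0 added above, and any other value raises a ValueError. As a quick sanity check (not part of the diff; the shapes below are illustrative only), the manual softmax(Q @ K^T * scale) @ V path used by SanaAttnProcessor3_0.scaled_dot_product_attention should agree with F.scaled_dot_product_attention when no mask is given and dropout is disabled:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative shapes: (batch, heads, sequence_length, head_dim)
query = torch.randn(2, 4, 8, 16)
key = torch.randn(2, 4, 10, 16)
value = torch.randn(2, 4, 10, 16)

# Manual path, mirroring SanaAttnProcessor3_0 with attn_mask=None and dropout_p=0.0
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
manual = attn_weight @ value

# Fused PyTorch 2.0 path used by SanaAttnProcessor2_0
reference = F.scaled_dot_product_attention(query, key, value)

print(torch.allclose(manual, reference, atol=1e-5))  # expected: True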