 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ...loaders import PeftAdapterMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.attention import FeedForward
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
-from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
+from ...loaders import PeftAdapterMixin
+from ..cache_utils import CacheMixin


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -123,10 +124,11 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         # 1. QKV projections
@@ -156,6 +158,17 @@ def __call__(
             )

         # 4. Attention
+        if attention_mask is not None:
+            # build an attention mask for the concatenated [text, image] sequence
+            text_attention_mask = attention_mask.float().to(query.device)
+            attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            attention_mask[:, :text_seq_length] = text_attention_mask
+            attention_mask = attention_mask.unsqueeze(2)
+            attention_mask_matrix = attention_mask @ attention_mask.mT
+            attention_mask_matrix = attention_mask_matrix == 1
+            attention_mask_matrix = attention_mask_matrix.unsqueeze(1)
+            attention_mask = attention_mask_matrix
+
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -203,6 +216,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -223,6 +238,8 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -233,8 +250,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)

-        ff_output = self.ff(norm_hidden_states)
-        ff_output_context = self.ff(norm_encoder_hidden_states)
+        ff_output = self.ff(norm_hidden_states, **kwargs)
+        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)

@@ -381,6 +398,8 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         batch_size, num_channels, height, width = hidden_states.shape

@@ -391,6 +410,7 @@ def forward(
         p = self.config.patch_size
         post_patch_height = height // p
         post_patch_width = width // p
+
         hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)

         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -400,11 +420,11 @@ def forward(
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )

         # 4. Output norm & projection
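
A minimal, self-contained sketch of what the new mask construction in the attention processor does: it expands a per-token text padding mask into a boolean `(batch, 1, seq, seq)` matrix that can be passed to `F.scaled_dot_product_attention`. The shapes, values, and intermediate names (`mask_1d`, `mask_2d`) below are illustrative assumptions, not code from this diff.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not values from the PR).
batch_size, text_seq_length, image_seq_length, num_heads, head_dim = 2, 4, 6, 1, 8

# 1 = real text token, 0 = padding, as produced by the tokenizer's attention mask.
text_attention_mask = torch.tensor([[1, 1, 1, 0],
                                    [1, 1, 0, 0]]).float()

# Extend the mask to the concatenated [text, image] sequence; image tokens are never masked.
mask_1d = torch.ones((batch_size, text_seq_length + image_seq_length))
mask_1d[:, :text_seq_length] = text_attention_mask

# The outer product of the 1D mask with itself gives a (B, S, S) matrix that is True
# only where both the query and the key position are valid tokens; unsqueeze(1)
# broadcasts it across heads to match SDPA's expected (B, H, S, S) mask shape.
mask_1d = mask_1d.unsqueeze(2)            # (B, S, 1)
mask_2d = (mask_1d @ mask_1d.mT) == 1     # (B, S, S), bool
mask_2d = mask_2d.unsqueeze(1)            # (B, 1, S, S)

# The boolean mask plugs directly into scaled_dot_product_attention: True = attend.
query = key = value = torch.randn(batch_size, num_heads, text_seq_length + image_seq_length, head_dim)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_2d)
print(out.shape)  # torch.Size([2, 1, 10, 8])
# Note: rows belonging to padded text queries have every key masked out, so their
# outputs come back as NaN; they correspond to padding and are ignored downstream.
```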