Commit 264060e

[cogview4] Add attention mask support to transformer model
1 parent 2f74c4e commit 264060e


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 28 additions & 8 deletions
@@ -17,16 +17,17 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ...loaders import PeftAdapterMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.attention import FeedForward
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
-from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
+from ...loaders import PeftAdapterMixin
+from ..cache_utils import CacheMixin
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -123,10 +124,11 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         # 1. QKV projections
@@ -156,6 +158,17 @@ def __call__(
             )
 
         # 4. Attention
+        if attention_mask is not None:
+            # construct the attention mask for the concatenated text + image sequence
+            text_attention_mask = attention_mask.float().to(query.device)
+            attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            attention_mask[:, :text_seq_length] = text_attention_mask
+            attention_mask = attention_mask.unsqueeze(2)
+            attention_mask_matrix = attention_mask @ attention_mask.mT
+            attention_mask_matrix = attention_mask_matrix == 1
+            attention_mask_matrix = attention_mask_matrix.unsqueeze(1)
+            attention_mask = attention_mask_matrix
+
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
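For context, the new block turns the 1D text padding mask into a full 2D attention matrix for the concatenated text + image sequence: image tokens are always valid, and the per-sample outer product keeps a query/key pair only when both positions are real tokens. A minimal standalone sketch of the same construction, with toy shapes that are purely illustrative:

import torch
import torch.nn.functional as F

# Toy sizes; the real values come from the text and image streams at runtime.
batch_size, num_heads, head_dim = 2, 4, 8
text_seq_length, image_seq_length = 5, 16
total_len = text_seq_length + image_seq_length

# 1D padding mask over the text tokens: 1 = real token, 0 = padding.
text_attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 1]], dtype=torch.long)

# Extend to the concatenated text + image sequence; image tokens are never padded.
full_mask = torch.ones((batch_size, total_len))
full_mask[:, :text_seq_length] = text_attention_mask.float()

# Per-sample outer product -> (B, L, L): True only where both the query and the
# key position are real tokens; unsqueeze(1) lets it broadcast over heads.
full_mask = full_mask.unsqueeze(2)             # (B, L, 1)
mask_matrix = (full_mask @ full_mask.mT) == 1  # (B, L, L), bool
mask_matrix = mask_matrix.unsqueeze(1)         # (B, 1, L, L)

# scaled_dot_product_attention accepts a boolean mask of this shape directly.
q = k = v = torch.randn(batch_size, num_heads, total_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_matrix)
print(out.shape)  # torch.Size([2, 4, 21, 8])

The boolean (B, 1, L, L) mask broadcasts across attention heads; rows belonging to padded query positions end up fully masked, which only affects the padding tokens themselves.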
@@ -203,6 +216,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -223,6 +238,8 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -233,8 +250,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)
 
-        ff_output = self.ff(norm_hidden_states)
-        ff_output_context = self.ff(norm_encoder_hidden_states)
+        ff_output = self.ff(norm_hidden_states, **kwargs)
+        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
 
@@ -381,6 +398,8 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         batch_size, num_channels, height, width = hidden_states.shape
 
@@ -391,6 +410,7 @@ def forward(
         p = self.config.patch_size
         post_patch_height = height // p
         post_patch_width = width // p
+
         hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
 
         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -400,11 +420,11 @@ def forward(
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
 
         # 4. Output norm & projection
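With the forward signatures above, a caller can now thread a text padding mask down to every attention block. A hypothetical call sketch (not part of this commit), assuming `transformer` is an already-loaded CogView4Transformer2DModel and that every shape and value below is an illustrative placeholder rather than the real CogView4 configuration:

import torch

# Placeholder inputs; real shapes depend on the checkpoint's config.
batch_size, text_len = 2, 32
latents = torch.randn(batch_size, 16, 64, 64)            # (B, C, H, W) latent input
prompt_embeds = torch.randn(batch_size, text_len, 4096)  # text encoder hidden states
timestep = torch.tensor([999] * batch_size)

# 1 = real token, 0 = padding (e.g. a tokenizer's attention_mask output).
text_mask = torch.ones(batch_size, text_len, dtype=torch.long)
text_mask[0, 20:] = 0  # first prompt is shorter and padded out

sample = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    original_size=torch.tensor([[1024, 1024]] * batch_size),
    target_size=torch.tensor([[1024, 1024]] * batch_size),
    crop_coords=torch.tensor([[0, 0]] * batch_size),
    attention_mask=text_mask,  # new: padded text tokens are masked out in every block
    return_dict=False,
)[0]

When attention_mask is omitted, behavior is unchanged, since the processor only builds the 2D mask when one is passed.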
