
Commit c339be0

[cogview4] Add attention mask support to transformer model
1 parent 2f74c4e commit c339be0

File tree

1 file changed: +28, -9 lines

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 28 additions & 9 deletions
@@ -17,16 +17,17 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ...loaders import PeftAdapterMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.attention import FeedForward
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
-from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
+from ...loaders import PeftAdapterMixin
+from ..cache_utils import CacheMixin
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -123,10 +124,11 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         # 1. QKV projections
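
For orientation on the new shape bookkeeping above: the processor places the text tokens ahead of the image tokens in one joint sequence, which is why both text_seq_length and image_seq_length are now unpacked. A minimal sketch of that layout; the sizes are assumptions for illustration, not values from the commit:

import torch

# Illustrative sizes only; assumptions for this sketch, not values from the commit.
batch_size, text_seq_length, image_seq_length, embed_dim = 2, 5, 16, 8

encoder_hidden_states = torch.randn(batch_size, text_seq_length, embed_dim)  # text tokens
hidden_states = torch.randn(batch_size, image_seq_length, embed_dim)         # image patch tokens

# Text tokens come first, then image tokens, so a padding mask defined over the text
# stream has to be widened to cover text_seq_length + image_seq_length positions.
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
print(joint.shape)  # torch.Size([2, 21, 8])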
@@ -156,8 +158,18 @@ def __call__(
         )
 
         # 4. Attention
+        if attention_mask is not None:
+            # construct attention_mask for concated sequence
+            text_attention_mask = attention_mask.float().to(query.device)
+            attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            attention_mask[:, :text_seq_length] = text_attention_mask
+            attention_mask = attention_mask.unsqueeze(2)
+            attention_mask_matrix = attention_mask @ attention_mask.mT
+            attention_mask_matrix = attention_mask_matrix == 1
+            attention_mask_matrix = attention_mask_matrix.unsqueeze(1)
+
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask_matrix, dropout_p=0.0, is_causal=False
         )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
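
The hunk above turns a per-token text padding mask into a full attention matrix via an outer product: a query/key pair is kept only when both positions are unmasked, and image positions are always unmasked. Below is a standalone sketch of the same construction, assuming the mask is a (batch, text_seq_length) tensor of 0s and 1s; the helper name and toy sizes are illustrative, not part of the commit:

import torch
import torch.nn.functional as F


def build_joint_attention_mask(text_mask, text_seq_length, image_seq_length):
    # text_mask: (batch, text_seq_length) with 1 for real text tokens, 0 for padding.
    batch_size = text_mask.shape[0]
    joint = torch.ones((batch_size, text_seq_length + image_seq_length), device=text_mask.device)
    joint[:, :text_seq_length] = text_mask.float()
    # Outer product: a query/key pair survives only if both positions are unmasked.
    keep = joint.unsqueeze(2) @ joint.unsqueeze(1)   # (batch, S, S)
    return (keep == 1).unsqueeze(1)                  # (batch, 1, S, S), broadcast over heads


# Toy padding mask: the first sample has two padded text positions (illustrative values).
text_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
attn_mask = build_joint_attention_mask(text_mask, text_seq_length=5, image_seq_length=16)

# query/key/value: (batch, heads, text + image tokens, head_dim). Fully masked (padded)
# query rows come back as NaN, matching the behavior of the boolean mask in the diff above.
query = key = value = torch.randn(2, 4, 21, 8)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)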
@@ -203,6 +215,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -223,6 +237,8 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -233,8 +249,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)
 
-        ff_output = self.ff(norm_hidden_states)
-        ff_output_context = self.ff(norm_encoder_hidden_states)
+        ff_output = self.ff(norm_hidden_states, **kwargs)
+        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
 
@@ -381,6 +397,8 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         batch_size, num_channels, height, width = hidden_states.shape
 
@@ -391,6 +409,7 @@
         p = self.config.patch_size
         post_patch_height = height // p
         post_patch_width = width // p
+
         hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
 
         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -400,11 +419,11 @@ def forward(
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
 
         # 4. Output norm & projection
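
In the gradient-checkpointing branch the new mask is simply forwarded as an extra positional argument. A self-contained sketch of that pattern using the public torch.utils.checkpoint API as a stand-in for diffusers' internal _gradient_checkpointing_func wrapper; the toy block and sizes are assumptions for illustration:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Toy stand-in for a transformer block that accepts an optional mask argument.
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        out = self.proj(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask.unsqueeze(-1)  # zero out padded positions
        return out


block = ToyBlock()
hidden_states = torch.randn(2, 5, 8, requires_grad=True)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float32)

# The extra tensor is passed positionally through the checkpoint wrapper, mirroring how
# the diff forwards attention_mask (and **kwargs) to each transformer block.
out = checkpoint(block, hidden_states, attention_mask, use_reentrant=False)
out.sum().backward()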
