diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 36f914f0b5c1..d9100b2f54d0 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -1068,17 +1068,15 @@ def forward(
         latent_sequence_length = hidden_states.shape[1]
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
-        attention_mask = torch.zeros(
+        attention_mask = torch.ones(
             batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
         )  # [B, N]
-
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
-
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0)  # [1, N]
+        mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
+        attention_mask = attention_mask.masked_fill(mask_indices, False)
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]

         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
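
For reference, a minimal standalone sketch showing why the two constructions are equivalent (the tensor sizes and per-sample length values below are made up for illustration; `effective_condition_sequence_length` stands in for `encoder_attention_mask.sum(dim=1, dtype=torch.int)` from the patch):

import torch

# Hypothetical sizes: 2 samples, 6 latent tokens, up to 4 text tokens.
batch_size, latent_sequence_length, condition_sequence_length = 2, 6, 4
sequence_length = latent_sequence_length + condition_sequence_length
# Per-sample count of valid text tokens (illustrative values).
effective_condition_sequence_length = torch.tensor([4, 2], dtype=torch.int)
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length  # [B,]

# Old construction: start from all-False and fill each row in a Python loop.
loop_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
for i in range(batch_size):
    loop_mask[i, : effective_sequence_length[i]] = True

# New construction: start from all-True and switch off the padded positions
# with one broadcasted comparison, avoiding the per-sample loop.
vectorized_mask = torch.ones(batch_size, sequence_length, dtype=torch.bool)
indices = torch.arange(sequence_length).unsqueeze(0)  # [1, N]
mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
vectorized_mask = vectorized_mask.masked_fill(mask_indices, False)

assert torch.equal(loop_mask, vectorized_mask)

The vectorized form replaces batch_size separate indexed writes with a single broadcasted comparison, so mask construction no longer requires a host-side Python loop over GPU tensors.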