From 700d1a76cc61a3d917accb50b5095dd6c2cd8581 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Thu, 28 Nov 2024 18:14:07 +0700
Subject: [PATCH 1/5] support attention mask

---
 src/diffusers/models/attention_processor.py | 38 +++++++++++++++++++--
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index ffbf4a0056c6..ba9667ad14e8 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1089,7 +1089,7 @@ def __call__(
     ) -> torch.FloatTensor:
         residual = hidden_states
 
-        batch_size = hidden_states.shape[0]
+        batch_size, sequence_length, _ = hidden_states.shape
 
         # `sample` projections.
         query = attn.to_q(hidden_states)
@@ -1129,11 +1129,27 @@ def __call__(
         if attn.norm_added_k is not None:
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
 
+        query_pad_size = query.size(2)
+        key_pad_size = key.size(2)
+
         query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if attention_mask is not None:
+            padding_shape = (attention_mask.shape[0], query_pad_size)
+            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2)  # N, Iq + Tq, 1
+
+            padding_shape = (attention_mask.shape[0], key_pad_size)
+            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1)  # N, 1, Ik + Tk
+
+            attention_mask = torch.bmm(query_attention_mask, key_attention_mask)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -1896,18 +1912,34 @@ def __call__(
         if attn.norm_added_k is not None:
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
 
+        query_pad_size = query.size(2)
+        key_pad_size = key.size(2)
+
         # attention
         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
+        if attention_mask is not None:
+            padding_shape = (attention_mask.shape[0], query_pad_size)
+            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2)  # N, Iq + Tq, 1
+
+            padding_shape = (attention_mask.shape[0], key_pad_size)
+            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1)  # N, 1, Ik + Tk
+
+            attention_mask = torch.bmm(query_attention_mask, key_attention_mask)
+
         if image_rotary_emb is not None:
             from .embeddings import apply_rotary_emb
 
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False, attention_mask=attention_mask
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 

From 799ee96636f790bfc675fb09c816ce334312c9ea Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Mon, 23 Dec 2024 10:15:28 +0700
Subject: [PATCH 2/5] resolve conflict

---
 src/diffusers/models/attention_processor.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index a4f3d7a6338b..c8716e7643c8 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2349,11 +2349,7 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(
-<<<<<<< HEAD
-            query, key, value, dropout_p=0.0, is_causal=False, attention_mask=attention_mask
-=======
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
->>>>>>> main
         )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

From 01b9179e3eac96e56228e3ff79cba01ad3d5f388 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Mon, 23 Dec 2024 13:57:58 +0700
Subject: [PATCH 3/5] refactor mask making

---
 src/diffusers/models/attention_processor.py | 56 ++++++++++++---------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index c8716e7643c8..aa36b6603478 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -731,6 +731,36 @@ def prepare_attention_mask(
 
         return attention_mask
 
+    def prepare_joint_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int
+    ) -> torch.Tensor:
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        remaining_length: int = target_length - current_length
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], remaining_length)
+                padding = torch.ones(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([padding, attention_mask], dim=2)
+            else:
+                attention_mask = F.pad(attention_mask, (remaining_length, 0), value=1.0)
+
+        if attention_mask.dim() == 3:
+            # If provided attention mask has shape [batch_size, target_seq_length, src_seq_length],
+            # we only need to broadcast it to all the heads
+            attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # If provided attention mask has shape [batch_size, seq_length],
+            # we broadcast both the heads and the target sequences,
+            # there is no need to mask all the lines for target padding token as it would not affect other non-padding tokens
+            attention_mask = attention_mask[:, None, None, :]
+
+        return attention_mask
+
     def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
@@ -1454,23 +1484,12 @@ def __call__(
         if attn.norm_added_k is not None:
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
 
-        query_pad_size = query.size(2)
-        key_pad_size = key.size(2)
-
         query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
 
         if attention_mask is not None:
-            padding_shape = (attention_mask.shape[0], query_pad_size)
-            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
-            query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2)  # N, Iq + Tq, 1
-
-            padding_shape = (attention_mask.shape[0], key_pad_size)
-            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
-            key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1)  # N, 1, Ik + Tk
-
-            attention_mask = torch.bmm(query_attention_mask, key_attention_mask)
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, query.shape[2])
 
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -2323,24 +2342,13 @@ def __call__(
         if attn.norm_added_k is not None:
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
 
-        query_pad_size = query.size(2)
-        key_pad_size = key.size(2)
-
         # attention
         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if attention_mask is not None:
-            padding_shape = (attention_mask.shape[0], query_pad_size)
-            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
-            query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2)  # N, Iq + Tq, 1
-
-            padding_shape = (attention_mask.shape[0], key_pad_size)
-            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
-            key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1)  # N, 1, Ik + Tk
-
-            attention_mask = torch.bmm(query_attention_mask, key_attention_mask)
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, query.shape[2])
 
         if image_rotary_emb is not None:
             from .embeddings import apply_rotary_emb

From 1d4e9c412b6977468edf94a34dd915fbc2e451f1 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Mon, 23 Dec 2024 13:59:17 +0700
Subject: [PATCH 4/5] undo some changes

---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index aa36b6603478..073653e3af57 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1444,7 +1444,7 @@ def __call__(
     ) -> torch.FloatTensor:
         residual = hidden_states
 
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size = hidden_states.shape[0]
 
         # `sample` projections.
         query = attn.to_q(hidden_states)

From cf27270e0b74802c0aeb6243f63d0d0e834f84a6 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Mon, 23 Dec 2024 23:18:18 +0700
Subject: [PATCH 5/5] convert mask to float

---
 src/diffusers/models/attention_processor.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 073653e3af57..9d5f4a75ee1c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -732,7 +732,7 @@ def prepare_attention_mask(
         return attention_mask
 
     def prepare_joint_attention_mask(
-        self, attention_mask: torch.Tensor, target_length: int
+        self, attention_mask: torch.Tensor, target_length: int, dtype: torch.dtype
     ) -> torch.Tensor:
         if attention_mask is None:
             return attention_mask
@@ -759,6 +759,13 @@ def prepare_joint_attention_mask(
             # there is no need to mask all the lines for target padding token as it would not affect other non-padding tokens
             attention_mask = attention_mask[:, None, None, :]
 
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
+
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1489,7 +1496,9 @@ def __call__(
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
 
         if attention_mask is not None:
-            attention_mask = attn.prepare_joint_attention_mask(attention_mask, query.shape[2])
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)
+        else:
+            attention_mask = None
 
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -2348,7 +2357,9 @@ def __call__(
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if attention_mask is not None:
-            attention_mask = attn.prepare_joint_attention_mask(attention_mask, query.shape[2])
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)
+        else:
+            attention_mask = None
 
         if image_rotary_emb is not None:
             from .embeddings import apply_rotary_emb
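
The sketch below is an editorial illustration, not part of the patch series: a hypothetical, standalone reproduction of the final masking scheme in plain PyTorch, outside the diffusers `Attention` class. The helper name `make_joint_attention_mask` and the toy shapes are invented for the example. Following `prepare_joint_attention_mask` above, positions not covered by the provided keep-mask are padded with ones at the front of the joint sequence (so they are always attended), the mask is broadcast over heads and query positions, and it is converted to additive float form before being passed to `F.scaled_dot_product_attention`.

    import torch
    import torch.nn.functional as F

    def make_joint_attention_mask(mask: torch.Tensor, target_length: int, dtype: torch.dtype) -> torch.Tensor:
        # Pad the keep-mask (1.0 = attend, 0.0 = ignore) with ones at the front so it
        # spans the joint sequence, broadcast it over heads and query positions, and
        # convert it to the additive float form accepted by scaled_dot_product_attention.
        mask = F.pad(mask, (target_length - mask.shape[-1], 0), value=1.0)
        mask = mask[:, None, None, :].to(dtype)
        return (1.0 - mask) * torch.finfo(dtype).min

    batch, heads, joint_len, head_dim = 2, 4, 9, 8
    text_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])  # second sample has one padded text token
    query = torch.randn(batch, heads, joint_len, head_dim)
    key = torch.randn(batch, heads, joint_len, head_dim)
    value = torch.randn(batch, heads, joint_len, head_dim)

    attn_mask = make_joint_attention_mask(text_mask, key.shape[2], key.dtype)
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    print(out.shape)  # torch.Size([2, 4, 9, 8])

As the comment added in PATCH 5 explains, the `(1.0 - mask) * torch.finfo(dtype).min` conversion makes kept positions contribute 0.0 to the pre-softmax scores and masked positions a very large negative value, which is effectively the same as removing them from the attention.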