Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible #11297

Merged
32 changes: 20 additions & 12 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices
 
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
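
Note: the rewrite avoids materializing the full index grid and slicing it; the selected positional-embedding indices are computed directly with `torch.arange` and `torch.meshgrid`, which keeps the function as plain tensor arithmetic. Below is a minimal standalone sketch of the same index selection (the helper name and the default parameter values are assumptions chosen for illustration, not taken from this file):

```python
import torch

# Hypothetical standalone helper mirroring the logic of
# AuraFlowPatchEmbed.pe_selection_index_based_on_dim.
def select_pe_indices(h: int, w: int, patch_size: int = 2, pos_embed_max_size: int = 1024) -> torch.Tensor:
    # Number of patches along each spatial dimension of the current input
    h_p, w_p = h // patch_size, w // patch_size
    # The stored positional embedding is a square grid of h_max x w_max entries
    h_max = w_max = int(pos_embed_max_size**0.5)

    # Top-left corner of the centered h_p x w_p window inside the full grid
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2

    # Row/column indices of that window, combined into flattened 1D indices
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    return (row_idx * w_max + col_idx).flatten()

# A 32x32 latent with patch_size=2 selects a centered 16x16 window (256 indices)
# out of the full 32x32 positional-embedding grid.
print(select_pe_indices(32, 32).shape)  # torch.Size([256])
```

Because the selection is expressed purely in terms of tensor ops, wrapping it in `torch.compile(select_pe_indices)` should trace without graph breaks, which is the point of this PR.
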
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
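
The docstring edits in this hunk only bring the documented defaults in line with the constructor signature (in_channels=4, num_single_dit_layers=32, attention_head_dim=256, num_attention_heads=12, out_channels=4, pos_embed_max_size=1024). As a sanity check, here is a minimal sketch that instantiates a deliberately tiny variant; the small values below are assumptions chosen only to keep the example cheap, not recommended settings:

```python
from diffusers import AuraFlowTransformer2DModel

# Tiny configuration for illustration only; the real defaults are much larger.
tiny = AuraFlowTransformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,              # matches the corrected docstring default
    num_mmdit_layers=1,
    num_single_dit_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,      # inner dim = 2 * 8 = 16
    joint_attention_dim=32,
    caption_projection_dim=16,  # kept equal to the inner dim so the blocks line up
    out_channels=4,             # matches the corrected docstring default
    pos_embed_max_size=256,     # 16x16 PE grid, enough for sample_size=32 with patch_size=2
)
print(sum(p.numel() for p in tiny.parameters()))
```
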