From 85ea323fe95f8afe4b63a922482f011cd06c8581 Mon Sep 17 00:00:00 2001
From: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
Date: Wed, 9 Apr 2025 14:34:22 -0700
Subject: [PATCH 1/3] Update pe_selection_index_based_on_dim

---
 .../models/transformers/auraflow_transformer_2d.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 4938ed23c506..c903358dc9a9 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -78,11 +78,11 @@ def pe_selection_index_based_on_dim(self, h, w):
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
         original_pe_indexes = original_pe_indexes.view(h_max, w_max)
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
-        startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+        startw = w_max // 2 - w_p // 2
+        narrowed = torch.narrow(original_pe_indexes, 0, starth, h_p)
+        narrowed = torch.narrow(narrowed, 1, startw, w_p)
+
+        return narrowed.flatten()
 
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()

From 5c6c6791ea69cc7a0bf6a0e65c890ba4d6f5a556 Mon Sep 17 00:00:00 2001
From: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
Date: Fri, 11 Apr 2025 20:11:23 -0700
Subject: [PATCH 2/3] Make pe_selection_index_based_on_dim work with torch.compile
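
The previous implementation built a full-size index tensor with
torch.arange(self.pos_embed.shape[1]) on the default device and then
viewed and sliced it. This change computes the selected indices
arithmetically, on self.pos_embed's device, which is what the subject's
torch.compile compatibility refers to.

A small standalone sketch (not part of the patch itself) showing that the
arithmetic index computation selects the same positional-embedding indices
as the earlier view/slice approach; the grid size, latent size and patch
size below are arbitrary example values, not taken from a real config:

    # Standalone check: arithmetic index computation vs. view/slice selection.
    import torch

    pos_embed_max_size = 1024      # example: a 32 x 32 grid of positions
    patch_size = 2
    h, w = 48, 32                  # example latent height/width

    h_p, w_p = h // patch_size, w // patch_size
    h_max = w_max = int(pos_embed_max_size**0.5)
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2

    # Old approach: build the full index grid and slice out the center.
    grid = torch.arange(pos_embed_max_size).view(h_max, w_max)
    old = grid[starth : starth + h_p, startw : startw + w_p].flatten()

    # New approach: derive the same indices from row/column arithmetic.
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    new = (row_idx * w_max + col_idx).flatten()

    assert torch.equal(old, new)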
---
 .../transformers/auraflow_transformer_2d.py | 22 +++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index c903358dc9a9..4e291dd57b4f 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        startw = w_max // 2 - w_p // 2
-        narrowed = torch.narrow(original_pe_indexes, 0, starth, h_p)
-        narrowed = torch.narrow(narrowed, 1, startw, w_p)
-
-        return narrowed.flatten()
+        startw = w_max // 2 - w_p // 2
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices
 
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()

From ff5674c5d79e53c0a07b30d11076f6ccaaba92a8 Mon Sep 17 00:00:00 2001
From: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
Date: Tue, 15 Apr 2025 03:49:16 -0700
Subject: [PATCH 3/3] Fix AuraFlowTransformer2DModel's docstring default values
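
The docstring listed defaults (in_channels=16, num_single_dit_layers=4,
attention_head_dim=64, num_attention_heads=18, out_channels=16,
pos_embed_max_size=4096) that do not match the constructor, so this commit
updates the documented values.

A quick way to compare the documented defaults against the actual
constructor signature, assuming a diffusers install that exposes
AuraFlowTransformer2DModel (a verification aid only, not part of the
patch):

    # Print every constructor default so it can be compared with the docstring.
    import inspect

    from diffusers import AuraFlowTransformer2DModel

    params = inspect.signature(AuraFlowTransformer2DModel.__init__).parameters
    for name, param in params.items():
        if param.default is not inspect.Parameter.empty:
            print(f"{name} = {param.default}")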
---
 .../models/transformers/auraflow_transformer_2d.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 4e291dd57b4f..607558b0debf 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -270,17 +270,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]