Skip to content

Commit 2f1c326

Browse files
authored
Merge branch 'main' into metadata-lora
2 parents e98fb84 + b6156aa commit 2f1c326

File tree

2 files changed

+118
-13
lines changed

2 files changed

+118
-13
lines changed

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
7474
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
7575
# because original input are in flattened format, we have to flatten this 2d grid as well.
7676
h_p, w_p = h // self.patch_size, w // self.patch_size
77-
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
7877
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
79-
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
78+
79+
# Calculate the top-left corner indices for the centered patch grid
8080
starth = h_max // 2 - h_p // 2
81-
endh = starth + h_p
8281
startw = w_max // 2 - w_p // 2
83-
endw = startw + w_p
84-
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
85-
return original_pe_indexes.flatten()
82+
83+
# Generate the row and column indices for the desired patch grid
84+
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
85+
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
86+
87+
# Create a 2D grid of indices
88+
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
89+
90+
# Convert the 2D grid indices to flattened 1D indices
91+
selected_indices = (row_indices * w_max + col_indices).flatten()
92+
93+
return selected_indices
8694

8795
def forward(self, latent):
8896
batch_size, num_channels, height, width = latent.size()
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
275283
sample_size (`int`): The width of the latent images. This is fixed during training since
276284
it is used to learn a number of position embeddings.
277285
patch_size (`int`): Patch size to turn the input data into small patches.
278-
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
286+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
279287
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
280-
num_single_dit_layers (`int`, *optional*, defaults to 4):
288+
num_single_dit_layers (`int`, *optional*, defaults to 32):
281289
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
282290
representations.
283-
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
284-
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
291+
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
292+
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
285293
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
286294
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
287-
out_channels (`int`, defaults to 16): Number of output channels.
288-
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
295+
out_channels (`int`, defaults to 4): Number of output channels.
296+
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
289297
"""
290298

291299
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ...image_processor import VaeImageProcessor
1616
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
1717
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
18-
from ...utils import is_torch_xla_available, logging
18+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
1919
from ...utils.torch_utils import randn_tensor
2020
from ..pipeline_utils import DiffusionPipeline
2121
from .pipeline_output import HiDreamImagePipelineOutput
@@ -523,6 +523,7 @@ def interrupt(self):
523523
return self._interrupt
524524

525525
@torch.no_grad()
526+
@replace_example_docstring(EXAMPLE_DOC_STRING)
526527
def __call__(
527528
self,
528529
prompt: Union[str, List[str]] = None,
@@ -552,6 +553,102 @@ def __call__(
552553
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
553554
max_sequence_length: int = 128,
554555
):
556+
r"""
557+
Function invoked when calling the pipeline for generation.
558+
559+
Args:
560+
prompt (`str` or `List[str]`, *optional*):
561+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
562+
instead.
563+
prompt_2 (`str` or `List[str]`, *optional*):
564+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
565+
will be used instead.
566+
prompt_3 (`str` or `List[str]`, *optional*):
567+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
568+
will be used instead.
569+
prompt_4 (`str` or `List[str]`, *optional*):
570+
The prompt or prompts to be sent to `tokenizer_4` and `text_encoder_4`. If not defined, `prompt` is
571+
will be used instead.
572+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
573+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
574+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
575+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
576+
num_inference_steps (`int`, *optional*, defaults to 50):
577+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
578+
expense of slower inference.
579+
sigmas (`List[float]`, *optional*):
580+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
581+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
582+
will be used.
583+
guidance_scale (`float`, *optional*, defaults to 3.5):
584+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
585+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
586+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
587+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
588+
usually at the expense of lower image quality.
589+
negative_prompt (`str` or `List[str]`, *optional*):
590+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
591+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
592+
not greater than `1`).
593+
negative_prompt_2 (`str` or `List[str]`, *optional*):
594+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
595+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
596+
negative_prompt_3 (`str` or `List[str]`, *optional*):
597+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
598+
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
599+
negative_prompt_4 (`str` or `List[str]`, *optional*):
600+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_4` and
601+
`text_encoder_4`. If not defined, `negative_prompt` is used in all the text-encoders.
602+
num_images_per_prompt (`int`, *optional*, defaults to 1):
603+
The number of images to generate per prompt.
604+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
605+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
606+
to make generation deterministic.
607+
latents (`torch.FloatTensor`, *optional*):
608+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
609+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
610+
tensor will ge generated by sampling using the supplied random `generator`.
611+
prompt_embeds (`torch.FloatTensor`, *optional*):
612+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
613+
provided, text embeddings will be generated from `prompt` input argument.
614+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
615+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
616+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
617+
argument.
618+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
619+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
620+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
621+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
622+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
623+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
624+
input argument.
625+
output_type (`str`, *optional*, defaults to `"pil"`):
626+
The output format of the generate image. Choose between
627+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
628+
return_dict (`bool`, *optional*, defaults to `True`):
629+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
630+
attention_kwargs (`dict`, *optional*):
631+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
632+
`self.processor` in
633+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
634+
callback_on_step_end (`Callable`, *optional*):
635+
A function that calls at the end of each denoising steps during the inference. The function is called
636+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
637+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
638+
`callback_on_step_end_tensor_inputs`.
639+
callback_on_step_end_tensor_inputs (`List`, *optional*):
640+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
641+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
642+
`._callback_tensor_inputs` attribute of your pipeline class.
643+
max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
644+
645+
Examples:
646+
647+
Returns:
648+
[`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`:
649+
[`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
650+
returning a tuple, the first element is a list with the generated. images.
651+
"""
555652
height = height or self.default_sample_size * self.vae_scale_factor
556653
width = width or self.default_sample_size * self.vae_scale_factor
557654

0 commit comments

Comments
 (0)