diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8f53ae5f3fc8..c4b023ca47b3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -175,7 +175,7 @@ title: gguf - local: quantization/torchao title: torchao - - local: quantization/quanto + - local: quantization/quanto title: quanto title: Quantization Methods - sections: @@ -300,6 +300,8 @@ title: EasyAnimateTransformer3DModel - local: api/models/flux_transformer title: FluxTransformer2DModel + - local: api/models/hidream_image_transformer + title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel - local: api/models/hunyuan_video_transformer_3d @@ -446,6 +448,8 @@ title: Flux - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/hidream + title: HiDream-I1 - local: api/pipelines/hunyuandit title: Hunyuan-DiT - local: api/pipelines/hunyuan_video diff --git a/docs/source/en/api/models/hidream_image_transformer.md b/docs/source/en/api/models/hidream_image_transformer.md new file mode 100644 index 000000000000..4218e7f56bec --- /dev/null +++ b/docs/source/en/api/models/hidream_image_transformer.md @@ -0,0 +1,30 @@ + + +# HiDreamImageTransformer2DModel + +A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai). + +The model can be loaded with the following code snippet. + +```python +from diffusers import HiDreamImageTransformer2DModel + +transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## HiDreamImageTransformer2DModel + +[[autodoc]] HiDreamImageTransformer2DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hidream.md b/docs/source/en/api/pipelines/hidream.md new file mode 100644 index 000000000000..f728d3d90f4c --- /dev/null +++ b/docs/source/en/api/pipelines/hidream.md @@ -0,0 +1,43 @@ + + +# HiDreamImage + +[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. 
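For orientation, here is a minimal usage sketch adapted from the example docstring added in `pipeline_hidream_image.py` later in this diff; the Llama-3.1-8B-Instruct checkpoint, the scheduler settings, and the generation parameters are taken from that docstring and are illustrative rather than the only supported configuration:

```python
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel, UniPCMultistepScheduler

# Flow-matching configuration of UniPC, as used in the pipeline's example docstring.
scheduler = UniPCMultistepScheduler(flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True)

# The fourth text encoder (Llama 3.1 8B Instruct) is loaded separately and passed to the pipeline.
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=scheduler,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe(
    'A cat holding a sign that says "Hi-Dreams.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")
```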
+ + + +## Available models + +The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline: + +| Model name | Description | +|:---|:---| +| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - | +| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - | +| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - | + +## HiDreamImagePipeline + +[[autodoc]] HiDreamImagePipeline + - all + - __call__ + +## HiDreamImagePipelineOutput + +[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a0245e2fe3ee..6c3bb7d52e82 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -171,6 +171,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", + "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", @@ -368,6 +369,7 @@ "FluxInpaintPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -761,6 +763,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, + HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, @@ -937,6 +940,7 @@ FluxInpaintPipeline, FluxPipeline, FluxPriorReduxPipeline, + HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 218394af2843..3213a50057bf 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -76,6 +76,7 @@ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] @@ -151,6 +152,7 @@ DualTransformer2DModel, EasyAnimateTransformer3DModel, FluxTransformer2DModel, + HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanVideoTransformer3DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 5392935da02b..191484fd9692 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -21,6 +21,7 @@ from .transformer_cogview4 import CogView4Transformer2DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py new file mode 100644 index 
000000000000..2bdf7d152268 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -0,0 +1,896 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.modeling_utils import ModelMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention +from ..embeddings import TimestepEmbedding, Timesteps + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HiDreamImageFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class HiDreamImagePooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size) + + def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor: + return self.pooled_embedder(pooled_embed) + + +class HiDreamImageTimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +class HiDreamImageOutEmbed(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1) + hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + hidden_states = self.linear(hidden_states) + return hidden_states + + +class HiDreamImagePatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True) + + def forward(self, latent): + latent = self.proj(latent) + return latent + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> 
torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + +class HiDreamImageEmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +@maybe_allow_in_graph +class HiDreamAttention(Attention): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor=None, + out_dim: int = None, + single: bool = False, + ): + super(Attention, self).__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + self.to_q = nn.Linear(query_dim, self.inner_dim) + self.to_k = nn.Linear(self.inner_dim, self.inner_dim) + self.to_v = nn.Linear(self.inner_dim, self.inner_dim) + self.to_out = nn.Linear(self.inner_dim, self.out_dim) + self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps) + self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps) + + if not single: + self.to_q_t = nn.Linear(query_dim, self.inner_dim) + self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim) + self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim) + self.to_out_t = nn.Linear(self.inner_dim, self.out_dim) + self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) + self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) + + self.set_processor(processor) + + def forward( + self, + norm_hidden_states: torch.Tensor, + hidden_states_masks: torch.Tensor = None, + norm_encoder_hidden_states: torch.Tensor = None, + image_rotary_emb: torch.Tensor = None, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states=norm_hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + +class HiDreamAttnProcessor: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn: HiDreamAttention, + hidden_states: torch.Tensor, + hidden_states_masks: Optional[torch.Tensor] = None, 
+ encoder_hidden_states: Optional[torch.Tensor] = None, + image_rotary_emb: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + dtype = hidden_states.dtype + batch_size = hidden_states.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype) + value_i = attn.to_v(hidden_states) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if hidden_states_masks is not None: + key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype) + value_t = attn.to_v_t(encoder_hidden_states) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == image_rotary_emb.shape[-3] * 2: + query, key = apply_rope(query, key, image_rotary_emb) + + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = "softmax" + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # print(bsz, seq_len, h) + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, 
sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_( + 1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device) + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + ): + super().__init__() + self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2) + self.experts = nn.ModuleList( + [HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)] + ) + self.gate = MoEGate( + embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = torch.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + # y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") + return expert_cache + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear = 
nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +@maybe_allow_in_graph +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + + # 1. Attention + self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor=HiDreamAttnProcessor(), + single=True, + ) + + # 3. Feed-forward + self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim=dim, + hidden_dim=4 * dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + ) + else: + self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_masks: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: torch.Tensor = None, + ) -> torch.Tensor: + wtype = hidden_states.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[ + :, None + ].chunk(6, dim=-1) + + # 1. MM-Attention + norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_hidden_states, + hidden_states_masks, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = gate_msa_i * attn_output_i + hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype)) + hidden_states = ff_output_i + hidden_states + return hidden_states + + +@maybe_allow_in_graph +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True)) + + # 1. Attention + self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor=HiDreamAttnProcessor(), + single=False, + ) + + # 3. 
Feed-forward + self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim=dim, + hidden_dim=4 * dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + ) + else: + self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_masks: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: torch.Tensor = None, + ) -> torch.Tensor: + wtype = hidden_states.dtype + ( + shift_msa_i, + scale_msa_i, + gate_msa_i, + shift_mlp_i, + scale_mlp_i, + gate_mlp_i, + shift_msa_t, + scale_msa_t, + gate_msa_t, + shift_mlp_t, + scale_mlp_t, + gate_mlp_t, + ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i + norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_hidden_states, + hidden_states_masks, + norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = gate_msa_i * attn_output_i + hidden_states + encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i + norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states) + ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states) + hidden_states = ff_output_i + hidden_states + encoder_hidden_states = ff_output_t + encoder_hidden_states + return hidden_states, encoder_hidden_states + + +class HiDreamBlock(nn.Module): + def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]): + super().__init__() + self.block = block + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_masks: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: torch.Tensor = None, + ) -> torch.Tensor: + return self.block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + +class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + 
num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim) + self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim) + self.x_embedder = HiDreamImagePatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.inner_dim, + ) + self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamBlock( + HiDreamImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + ) + ) + for _ in range(self.config.num_layers) + ] + ) + + self.single_stream_blocks = nn.ModuleList( + [ + HiDreamBlock( + HiDreamImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + ) + ) + for _ in range(self.config.num_single_layers) + ] + ) + + self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels) + + caption_channels = [ + caption_channels[1], + ] * (num_layers + num_single_layers) + [ + caption_channels[0], + ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: + if is_training: + B, S, F = x.shape + C = F // (self.config.patch_size * self.config.patch_size) + x = ( + x.reshape(B, S, self.config.patch_size, self.config.patch_size, C) + .permute(0, 4, 1, 2, 3) + .reshape(B, C, S, self.config.patch_size * self.config.patch_size) + ) + else: + x_arr = [] + p1 = self.config.patch_size + p2 = self.config.patch_size + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + t = x[i, : pH * pW].reshape(1, pH, pW, -1) + F_token = t.shape[-1] + C = F_token // (p1 * p2) + t = t.reshape(1, pH, pW, p1, p2, C) + t = t.permute(0, 5, 1, 3, 2, 4) + t = t.reshape(1, C, pH * p1, pW * p2) + x_arr.append(t) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.config.patch_size * self.config.patch_size + if isinstance(x, torch.Tensor): + B, C = x.shape[0], x.shape[1] + device = x.device + dtype 
= x.dtype + else: + B, C = len(x), x[0].shape[0] + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0 : img_size[0] * img_size[1]] = 1 + B, C, S, _ = x.shape + x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C) + elif isinstance(x, torch.Tensor): + B, C, Hp1, Wp2 = x.shape + pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size + x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size) + x = x.permute(0, 2, 4, 3, 5, 1) + x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + hidden_states: torch.Tensor, + timesteps: torch.LongTensor = None, + encoder_hidden_states: torch.Tensor = None, + pooled_embeds: torch.Tensor = None, + img_sizes: Optional[List[Tuple[int, int]]] = None, + img_ids: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + if hidden_states.shape[-2] != hidden_states.shape[-1]: + B, C, H, W = hidden_states.shape + patch_size = self.config.patch_size + pH, pW = H // patch_size, W // patch_size + out = torch.zeros( + (B, C, self.max_seq, patch_size * patch_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size) + hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5) + hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size) + out[:, :, 0 : pH * pW] = hidden_states + hidden_states = out + + # 0. 
time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + temb = timesteps + p_embedder + + hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if hidden_states_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = ( + img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2]) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + hidden_states = self.x_embedder(hidden_states) + + T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states[-1] + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + + encoder_hidden_states[-2].shape[1] + + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, + dtype=img_ids.dtype, + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids) + + # 2. 
Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat( + [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1 + ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + hidden_states_masks, + cur_encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + hidden_states, initial_encoder_hidden_states = block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=cur_encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if hidden_states_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=hidden_states_masks.device, + dtype=hidden_states_masks.dtype, + ) + hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + hidden_states_masks, + None, + temb, + image_rotary_emb, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
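+        # Only the image tokens are carried forward here; the text tokens concatenated for the
+        # single-stream blocks were sliced off above. What follows projects the remaining tokens
+        # through the adaLN-modulated output layer and unpatchifies them back to latent image space.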
+ output = self.final_layer(hidden_states, temb) + output = self.unpatchify(output, img_sizes, self.training) + if hidden_states_masks is not None: + hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len] + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, hidden_states_masks) + return Transformer2DModelOutput(sample=output, mask=hidden_states_masks) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a2f618857ac1..3007a991dbd9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -221,6 +221,7 @@ "EasyAnimateInpaintPipeline", "EasyAnimateControlPipeline", ] + _import_structure["hidream_image"] = ["HiDreamImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ "HunyuanVideoPipeline", @@ -585,6 +586,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hidream_image import HiDreamImagePipeline from .hunyuan_video import ( HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoImageToVideoPipeline, diff --git a/src/diffusers/pipelines/hidream_image/__init__.py b/src/diffusers/pipelines/hidream_image/__init__.py new file mode 100644 index 000000000000..498df900e68b --- /dev/null +++ b/src/diffusers/pipelines/hidream_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_hidream_image import HiDreamImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py new file mode 100644 index 000000000000..e16dedb53674 --- /dev/null +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -0,0 +1,739 @@ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast, + T5EncoderModel, + T5Tokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HiDreamImageTransformer2DModel +from ...schedulers import 
FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM + >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel + + >>> scheduler = UniPCMultistepScheduler( + ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True + ... ) + + >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... torch_dtype=torch.bfloat16, + ... ) + + >>> transformer = HiDreamImageTransformer2DModel.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> pipe = HiDreamImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", + ... scheduler=scheduler, + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... transformer=transformer, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> image = pipe( + ... 'A cat holding a sign that says "Hi-Dreams.ai".', + ... height=1024, + ... width=1024, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... generator=torch.Generator("cuda").manual_seed(0), + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HiDreamImagePipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + if getattr(self, "tokenizer_4", None) is not None: + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = 
text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_4.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True, + ) + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + prompt_4: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] + + prompt_embeds, pooled_prompt_embeds = self._encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_4 = negative_prompt_4 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + negative_prompt_4 = ( + batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_3=negative_prompt_3, + prompt_4=negative_prompt_4, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + prompt_4: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + ): + device = device or self._execution_device + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] + + if pooled_prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype + ) + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype + ) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + if prompt_embeds is None: + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_4 = prompt_4 or prompt + prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) + llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + + _, seq_len, _ = t5_prompt_embeds.shape + t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1) + t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + _, _, seq_len, dim = llama3_prompt_embeds.shape + llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) + llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] + + return prompt_embeds, pooled_prompt_embeds + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * 
division) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif prompt_embeds is not None: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] + else: + batch_size = 1 + + device = self._execution_device + + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_arr = [] + for n, p in zip(negative_prompt_embeds, prompt_embeds): + if len(n.shape) == 3: + prompt_embeds_arr.append(torch.cat([n, p], dim=0)) + else: + prompt_embeds_arr.append(torch.cat([n, p], dim=1)) + prompt_embeds = prompt_embeds_arr + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + device, + generator, + latents, + ) + + if latents.shape[-2] != latents.shape[-1]: + B, C, H, W = latents.shape + pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size + + img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) + img_ids = torch.zeros(pH, pW, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] + img_ids = img_ids.reshape(pH * pW, -1) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad[: pH * pW, :] = img_ids + + img_sizes = img_sizes.unsqueeze(0).to(latents.device) + img_ids = img_ids_pad.unsqueeze(0).to(latents.device) + if self.do_classifier_free_guidance: + img_sizes = img_sizes.repeat(2 * B, 1) + img_ids = img_ids.repeat(2 * B, 1, 1) + else: + img_sizes = img_ids = None + + # 5. Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if isinstance(self.scheduler, UniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device) # , shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timesteps=timestep, + encoder_hidden_states=prompt_embeds, + pooled_embeds=pooled_prompt_embeds, + img_sizes=img_sizes, + img_ids=img_ids, + return_dict=False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/hidream_image/pipeline_output.py b/src/diffusers/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 000000000000..1890a8a3f5f1 --- /dev/null +++ b/src/diffusers/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index dd9117ddca18..c2dffbb1d1b8 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -505,6 +505,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HiDreamImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2136131126e9..2dc6160b1e5c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -617,6 +617,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HiDreamImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/hidream/__init__.py b/tests/pipelines/hidream/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hidream/test_pipeline_hidream.py b/tests/pipelines/hidream/test_pipeline_hidream.py new file mode 100644 index 000000000000..597a20216882 --- /dev/null +++ b/tests/pipelines/hidream/test_pipeline_hidream.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+    AutoTokenizer,
+    CLIPTextConfig,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    LlamaForCausalLM,
+    T5EncoderModel,
+)
+
+from diffusers import (
+    AutoencoderKL,
+    FlowMatchEulerDiscreteScheduler,
+    HiDreamImagePipeline,
+    HiDreamImageTransformer2DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = HiDreamImagePipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+    required_optional_params = PipelineTesterMixin.required_optional_params
+    test_layerwise_casting = True
+    supports_dduf = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        transformer = HiDreamImageTransformer2DModel(
+            patch_size=2,
+            in_channels=4,
+            out_channels=4,
+            num_layers=1,
+            num_single_layers=1,
+            attention_head_dim=8,
+            num_attention_heads=4,
+            caption_channels=[32, 16],
+            text_emb_dim=64,
+            num_routed_experts=4,
+            num_activated_experts=2,
+            axes_dims_rope=(4, 2, 2),
+            max_resolution=(32, 32),
+            llama_layers=(0, 1),
+        ).eval()
+        torch.manual_seed(0)
+        vae = AutoencoderKL(scaling_factor=0.3611, shift_factor=0.1159)
+        clip_text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            hidden_act="gelu",
+            projection_dim=32,
+            max_position_embeddings=128,
+        )
+
+        torch.manual_seed(0)
+        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+        torch.manual_seed(0)
+        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+        torch.manual_seed(0)
+        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        text_encoder_4.generation_config.pad_token_id = 1
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer_4 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        components = {
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder_3": text_encoder_3,
+            "tokenizer_3": tokenizer_3,
+            "text_encoder_4": text_encoder_4,
+            "tokenizer_4": tokenizer_4,
+            "transformer": transformer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+        }
+        return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs)[0]
+        generated_image = image[0]
+
+        self.assertEqual(generated_image.shape, (128, 128, 3))
+        # Sanity checks until a reference slice is committed: the postprocessed output should be
+        # finite and stay within the [0, 1] range produced by the image processor.
+        self.assertTrue(np.isfinite(generated_image).all())
+        self.assertGreaterEqual(generated_image.min(), 0.0)
+        self.assertLessEqual(generated_image.max(), 1.0)
+
+    def test_inference_batch_single_identical(self):
+        super().test_inference_batch_single_identical(expected_max_diff=3e-4)
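A minimal usage sketch for the new `HiDreamImagePipeline`, to complement the tests above. It assumes the `HiDream-ai/HiDream-I1-Full` checkpoint exposes all four tokenizers and text encoders directly through `from_pretrained`; if the Llama-based fourth encoder is hosted separately, it would need to be loaded and passed to the pipeline explicitly. The prompt, seed, and file name are illustrative only.

```python
import torch

from diffusers import HiDreamImagePipeline

# Sketch only: assumes the hub checkpoint bundles all four tokenizers/text encoders.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    torch_dtype=torch.bfloat16,
)
pipe.enable_vae_tiling()  # optional: lowers decode-time memory via the helper defined above
pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'hello world'",
    guidance_scale=5.0,  # the __call__ default; values <= 1 disable classifier-free guidance
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("hidream_sample.png")
```

Because `do_classifier_free_guidance` is simply `guidance_scale > 1`, passing `guidance_scale=1.0` skips negative-prompt encoding and the unconditional forward pass entirely.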