From ab2476bf095239dcbe9800c41c6c7972b3dfd365 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 11:49:44 +0100 Subject: [PATCH 01/11] update --- scripts/convert_hunyuan_video_to_diffusers.py | 63 +++++++++++++++++-- .../transformers/transformer_hunyuan_video.py | 12 +++- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 464c9e0fb954..4a7cc376e319 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -3,7 +3,7 @@ import torch from accelerate import init_empty_weights -from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration from diffusers import ( AutoencoderKLHunyuanVideo, @@ -134,6 +134,46 @@ def remap_single_transformer_blocks_(key, state_dict): VAE_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_CONFIGS = { + "HYVideo-T/2-cfgdistill": { + "in_channels": 16, + "out_channels": 16, + "num_attention_heads": 24, + "attention_head_dim": 128, + "num_layers": 20, + "num_single_layers": 40, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 2, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "guidance_embeds": True, + "text_embed_dim": 4096, + "pooled_projection_dim": 768, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + }, + "HYVideo-T/2": { + "in_channels": 16 * 2 + 1, + "out_channels": 16, + "num_attention_heads": 24, + "attention_head_dim": 128, + "num_layers": 20, + "num_single_layers": 40, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 2, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "guidance_embeds": False, + "text_embed_dim": 4096, + "pooled_projection_dim": 768, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + }, +} + + def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: state_dict[new_key] = state_dict.pop(old_key) @@ -149,11 +189,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: return state_dict -def convert_transformer(ckpt_path: str): +def convert_transformer(ckpt_path: str, transformer_type: str): original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + config = TRANSFORMER_CONFIGS[transformer_type] with init_empty_weights(): - transformer = HunyuanVideoTransformer3DModel() + transformer = HunyuanVideoTransformer3DModel(**config) for key in list(original_state_dict.keys()): new_key = key[:] @@ -205,6 +246,10 @@ def get_args(): parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + parser.add_argument( + "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys()) + ) + parser.add_argument("--flow_shift", type=float, default=7.0) return parser.parse_args() @@ -228,7 +273,7 @@ def get_args(): assert args.text_encoder_2_path is not None if args.transformer_ckpt_path is not None: - transformer = convert_transformer(args.transformer_ckpt_path) + transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type) transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") @@ -239,11 +284,17 @@ def get_args(): vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.save_pipeline: - text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + if args.transformer_type == "HYVideo-T/2-cfgdistill": + text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + else: + text_encoder = LlavaForConditionalGeneration.from_pretrained( + args.text_encoder_path, torch_dtype=torch.float16 + ) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) - scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift) pipe = HunyuanVideoPipeline( transformer=transformer, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index c78d13344d81..bb0cef057992 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -581,7 +581,11 @@ def __init__( self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + if guidance_embeds: + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + else: + self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim) # 2. RoPE self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) @@ -708,7 +712,11 @@ def forward( image_rotary_emb = self.rope(hidden_states) # 2. Conditional embeddings - temb = self.time_text_embed(timestep, guidance, pooled_projections) + if self.config.guidance_embeds: + temb = self.time_text_embed(timestep, guidance, pooled_projections) + else: + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.x_embedder(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) From 77abad38f969764d26b4d4c65bc1b35f99c021e9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 12:28:43 +0100 Subject: [PATCH 02/11] update --- scripts/convert_hunyuan_video_to_diffusers.py | 60 +- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 12 +- .../pipelines/hunyuan_video/__init__.py | 2 + .../pipeline_hunyuan_video_image2video.py | 761 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 6 files changed, 832 insertions(+), 20 deletions(-) create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 4a7cc376e319..1dce0490d358 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -3,11 +3,19 @@ import torch from accelerate import init_empty_weights -from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration +from transformers import ( + AutoModel, + AutoTokenizer, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + LlavaForConditionalGeneration, +) from diffusers import ( AutoencoderKLHunyuanVideo, FlowMatchEulerDiscreteScheduler, + HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, ) @@ -153,7 +161,7 @@ def remap_single_transformer_blocks_(key, state_dict): "rope_theta": 256.0, "rope_axes_dim": (16, 56, 56), }, - "HYVideo-T/2": { + "HYVideo-T/2-I2V": { "in_channels": 16 * 2 + 1, "out_channels": 16, "num_attention_heads": 24, @@ -286,23 +294,39 @@ def get_args(): if args.save_pipeline: if args.transformer_type == "HYVideo-T/2-cfgdistill": text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift) + + pipe = HunyuanVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") else: text_encoder = LlavaForConditionalGeneration.from_pretrained( args.text_encoder_path, torch_dtype=torch.float16 ) - - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") - text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) - tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) - scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift) - - pipe = HunyuanVideoPipeline( - transformer=transformer, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - scheduler=scheduler, - ) - pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift) + image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_2_path) + + pipe = HunyuanVideoImageToVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + image_processor=image_processor, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cfb0bd08f818..d5cfad915e3c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -313,6 +313,7 @@ "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "HunyuanSkyreelsImageToVideoPipeline", + "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", @@ -823,6 +824,7 @@ HunyuanDiTPAGPipeline, HunyuanDiTPipeline, HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e99162e7a7fe..8b76e109e754 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -222,7 +222,11 @@ "EasyAnimateControlPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] - _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"] + _import_structure["hunyuan_video"] = [ + "HunyuanVideoPipeline", + "HunyuanSkyreelsImageToVideoPipeline", + "HunyuanVideoImageToVideoPipeline", + ] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -570,7 +574,11 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .hunyuan_video import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoPipeline + from .hunyuan_video import ( + HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideoImageToVideoPipeline, + HunyuanVideoPipeline, + ) from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py index cc9d4729e175..d9cacad24f17 100644 --- a/src/diffusers/pipelines/hunyuan_video/__init__.py +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"] _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -35,6 +36,7 @@ else: from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline from .pipeline_hunyuan_video import HunyuanVideoPipeline + from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py new file mode 100644 index 000000000000..e97398d0fe4f --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -0,0 +1,761 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + LlamaTokenizerFast, + LlavaForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlavaForConditionalGeneration`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlavaForConditionalGeneration, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5a2818c2e245..ded30d16cf93 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -677,6 +677,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 655dcdada116bd233b000cc9f82aae59e6b7cf64 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 13:43:43 +0100 Subject: [PATCH 03/11] update --- .../pipeline_hunyuan_video_image2video.py | 190 ++++++++++++++---- 1 file changed, 150 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index e97398d0fe4f..67de868215ab 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -26,6 +26,7 @@ ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -75,15 +76,20 @@ DEFAULT_PROMPT_TEMPLATE = { "template": ( - "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " "1. The main content and theme of the video." "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" ), - "crop_start": 95, + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, } @@ -147,6 +153,20 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for image-to-video generation using HunyuanVideo. @@ -197,6 +217,7 @@ def __init__( scheduler=scheduler, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, + image_processor=image_processor, ) self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 @@ -205,6 +226,7 @@ def __init__( def _get_llama_prompt_embeds( self, + image: torch.Tensor, prompt: Union[str, List[str]], prompt_template: Dict[str, Any], num_videos_per_prompt: int = 1, @@ -212,6 +234,7 @@ def _get_llama_prompt_embeds( dtype: Optional[torch.dtype] = None, max_sequence_length: int = 256, num_hidden_layers_to_skip: int = 2, + image_embed_interleave: int = 2, ) -> Tuple[torch.Tensor, torch.Tensor]: device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -232,8 +255,8 @@ def _get_llama_prompt_embeds( return_attention_mask=False, ) crop_start = prompt_template_input["input_ids"].shape[-1] - # Remove <|eot_id|> token and placeholder {} - crop_start -= 2 + # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {} + crop_start -= 5 max_sequence_length += crop_start text_inputs = self.tokenizer( @@ -249,16 +272,84 @@ def _get_llama_prompt_embeds( text_input_ids = text_inputs.input_ids.to(device=device) prompt_attention_mask = text_inputs.attention_mask.to(device=device) + image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) + prompt_embeds = self.text_encoder( input_ids=text_input_ids, attention_mask=prompt_attention_mask, + pixel_values=image_embeds, output_hidden_states=True, ).hidden_states[-(num_hidden_layers_to_skip + 1)] prompt_embeds = prompt_embeds.to(dtype=dtype) + image_emb_len = prompt_template.get("image_emb_len", 576) + image_emb_start = prompt_template.get("image_emb_start", 5) + image_emb_end = prompt_template.get("image_emb_end", 581) + double_return_token_id = prompt_template.get("double_return_token_id", 271) + if crop_start is not None and crop_start > 0: - prompt_embeds = prompt_embeds[:, crop_start:] - prompt_attention_mask = prompt_attention_mask[:, crop_start:] + text_crop_start = crop_start - 1 + image_emb_len + batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ + :, -1 + ] + batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 + assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + + prompt_embed_list = [] + prompt_attention_mask_list = [] + image_embed_list = [] + image_attention_mask_list = [] + + for i in range(text_input_ids.shape[0]): + prompt_embed_list.append( + torch.cat( + [ + prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()], + prompt_embeds[i, assistant_crop_end[i].item() :], + ] + ) + ) + prompt_attention_mask_list.append( + torch.cat( + [ + prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()], + prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + ) + image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end]) + image_attention_mask_list.append( + torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype) + ) + + prompt_embed_list = torch.stack(prompt_embed_list) + prompt_attention_mask_list = torch.stack(prompt_attention_mask_list) + image_embed_list = torch.stack(image_embed_list) + image_attention_mask_list = torch.stack(image_attention_mask_list) + + if image_embed_interleave < 6: + image_embed_list = image_embed_list[:, ::image_embed_interleave, :] + image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] + + assert ( + prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0] + and image_embed_list.shape[0] == image_attention_mask_list.shape[0] + ) + + prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) + prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -310,6 +401,7 @@ def _get_clip_prompt_embeds( def encode_prompt( self, + image: torch.Tensor, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]] = None, prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, @@ -323,6 +415,7 @@ def encode_prompt( ): if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + image, prompt, prompt_template, num_videos_per_prompt, @@ -393,6 +486,7 @@ def check_inputs( def prepare_latents( self, + image: torch.Tensor, batch_size: int, num_channels_latents: int = 32, height: int = 720, @@ -403,24 +497,36 @@ def prepare_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) - - shape = ( - batch_size, - num_channels_latents, - (num_frames - 1) // self.vae_scale_factor_temporal + 1, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - return latents + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents * (1 - t) + + return latents, image_latents def enable_vae_slicing(self): r""" @@ -475,6 +581,7 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + image: PipelineImageInput, prompt: Union[str, List[str]] = None, prompt_2: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, @@ -632,9 +739,30 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # 3. Encode input prompt + # 3. Prepare latent variables + vae_dtype = self.vae.dtype + image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + num_channels_latents = (self.transformer.config.in_channels - 1) // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + image_latents[:, :, 1:] = 0 + mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) + mask[:, :, 1:] = 0 + + # 4. Encode input prompt transformer_dtype = self.transformer.dtype prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + image=image, prompt=prompt, prompt_2=prompt_2, prompt_template=prompt_template, @@ -651,6 +779,7 @@ def __call__( if do_true_cfg: negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + image=torch.full_like(image, fill_value=-1), prompt=negative_prompt, prompt_2=negative_prompt_2, prompt_template=prompt_template, @@ -669,23 +798,6 @@ def __call__( sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents, - ) - - # 6. Prepare guidance condition - guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 - # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -696,7 +808,7 @@ def __call__( continue self._current_timestep = t - latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latents, image_latents, mask], dim=1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -706,7 +818,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, pooled_projections=pooled_prompt_embeds, - guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -718,7 +829,6 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, encoder_attention_mask=negative_prompt_attention_mask, pooled_projections=negative_pooled_prompt_embeds, - guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] From 1e6ada6e7e04e757e87b55240c0f67def316d8f5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 19:53:55 +0100 Subject: [PATCH 04/11] add tests --- .../hunyuan_video/test_hunyuan_image2video.py | 364 ++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 tests/pipelines/hunyuan_video/test_hunyuan_image2video.py diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py new file mode 100644 index 000000000000..48a26f6855e5 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py @@ -0,0 +1,364 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + LlamaConfig, + LlamaModel, + LlamaTokenizer, +) + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoImageToVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanVideoImageToVideoPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase +): + pipeline_class = HunyuanVideoImageToVideoPipeline + params = frozenset( + ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"] + ) + batch_params = frozenset(["prompt", "image"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + + # there is no xformers processor for Flux + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=2 * 4 + 1, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=num_layers, + num_single_layers=num_single_layers, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=False, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + image_processor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": image_height, + "width": image_width, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass + + @unittest.skip( + "Encode prompt currently does not work in isolation because of requiring image embeddings from image processor. The test does not handle this case, or we need to rewrite encode_prompt." + ) + def test_encode_prompt_works_in_isolation(self): + pass From e978876bbea98275ac02901ef53b77f52134a4fd Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 19:54:01 +0100 Subject: [PATCH 05/11] update --- scripts/convert_hunyuan_video_to_diffusers.py | 2 +- .../pipeline_hunyuan_video_image2video.py | 29 +++++-------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 1dce0490d358..ca6ec152f66f 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -317,7 +317,7 @@ def get_args(): text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift) - image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_2_path) + image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path) pipe = HunyuanVideoImageToVideoPipeline( transformer=transformer, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index 67de868215ab..29bb627fef2a 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import PIL.Image import torch from transformers import ( CLIPImageProcessor, @@ -26,7 +27,6 @@ ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -220,6 +220,7 @@ def __init__( image_processor=image_processor, ) + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -240,8 +241,6 @@ def _get_llama_prompt_embeds( dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - prompt = [prompt_template["template"].format(p) for p in prompt] crop_start = prompt_template.get("crop_start", None) @@ -351,13 +350,6 @@ def _get_llama_prompt_embeds( prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) - prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) - return prompt_embeds, prompt_attention_mask def _get_clip_prompt_embeds( @@ -372,7 +364,6 @@ def _get_clip_prompt_embeds( dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) text_inputs = self.tokenizer_2( prompt, @@ -392,11 +383,6 @@ def _get_clip_prompt_embeds( ) prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) - return prompt_embeds def encode_prompt( @@ -581,7 +567,7 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: PipelineImageInput, + image: PIL.Image.Image, prompt: Union[str, List[str]] = None, prompt_2: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, @@ -741,10 +727,10 @@ def __call__( # 3. Prepare latent variables vae_dtype = self.vae.dtype - image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) num_channels_latents = (self.transformer.config.in_channels - 1) // 2 latents, image_latents = self.prepare_latents( - image, + image_tensor, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -778,8 +764,9 @@ def __call__( pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) if do_true_cfg: + black_image = PIL.Image.new("RGB", (width, height), 0) negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( - image=torch.full_like(image, fill_value=-1), + image=black_image, prompt=negative_prompt, prompt_2=negative_prompt_2, prompt_template=prompt_template, @@ -808,7 +795,7 @@ def __call__( continue self._current_timestep = t - latent_model_input = torch.cat([latents, image_latents, mask], dim=1) + latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) From e13231c41bc6f7b71a515b5fe84c1c7a49f15828 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 19:56:12 +0100 Subject: [PATCH 06/11] add model tests --- .../test_models_transformer_hunyuan_video.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index ac95fe6f4544..2b81dc876433 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -154,3 +154,68 @@ def test_output(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"HunyuanVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 2 * 4 + 1 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + } + + @property + def input_shape(self): + return (8, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 2 * 4 + 1, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": False, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output(expected_output_shape=(1, *self.output_shape)) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) From a879a2284b1746f07d729ababfbc239c4d8ac154 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 20:02:50 +0100 Subject: [PATCH 07/11] update docs --- docs/source/en/api/pipelines/hunyuan_video.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md index e16b5a4b250c..f8039902976e 100644 --- a/docs/source/en/api/pipelines/hunyuan_video.md +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -49,7 +49,8 @@ The following models are available for the image-to-video pipeline: | Model name | Description | |:---|:---| -| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. | +| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. | +| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) | ## Quantization From 0a5a820b75102fa5b741ca8fce4442d1484de982 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 22:47:16 +0100 Subject: [PATCH 08/11] update --- .../hunyuan_video/pipeline_hunyuan_video_image2video.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index 29bb627fef2a..7746c3b761d8 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -338,7 +338,7 @@ def _get_llama_prompt_embeds( image_embed_list = torch.stack(image_embed_list) image_attention_mask_list = torch.stack(image_attention_mask_list) - if image_embed_interleave < 6: + if 0 < image_embed_interleave < 6: image_embed_list = image_embed_list[:, ::image_embed_interleave, :] image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] @@ -846,8 +846,9 @@ def __call__( latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) + video = video[:, :, 4:, :, :] else: - video = latents + video = latents[:, :, 1:, :, :] # Offload all models self.maybe_free_model_hooks() From f6a07e5b38522eec8a2614f77e2e504799052243 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 22:55:17 +0100 Subject: [PATCH 09/11] update example --- .../pipeline_hunyuan_video_image2video.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index 7746c3b761d8..f421a8b77abe 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -51,23 +51,32 @@ Examples: ```python >>> import torch - >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel - >>> from diffusers.utils import export_to_video + >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import load_image, export_to_video - >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V" >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ... ) - >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe = HunyuanVideoImageToVideoPipeline.from_pretrained( + ... model_id, transformer=transformer, torch_dtype=torch.float16 + ... ) >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") + >>> prompt = "A man with short gray hair plays a red electric guitar." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" + ... ) + >>> output = pipe( - ... prompt="A cat walks on the grass, realistic", - ... height=320, - ... width=512, - ... num_frames=61, - ... num_inference_steps=30, + ... image=image, + ... height=720, + ... width=1280, + ... num_frames=129, + ... prompt=prompt, + ... true_cfg_scale=1.0, + ... guidance_scale=1.0, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=15) ``` From 39a1ce8747dfa5060a705c382d13b34057ea4bb7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 22:57:03 +0100 Subject: [PATCH 10/11] fix defaults --- .../pipeline_hunyuan_video_image2video.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index f421a8b77abe..99046661ea90 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -69,15 +69,7 @@ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" ... ) - >>> output = pipe( - ... image=image, - ... height=720, - ... width=1280, - ... num_frames=129, - ... prompt=prompt, - ... true_cfg_scale=1.0, - ... guidance_scale=1.0, - ... ).frames[0] + >>> output = pipe(image=image, prompt=prompt).frames[0] >>> export_to_video(output, "output.mp4", fps=15) ``` """ @@ -587,7 +579,7 @@ def __call__( num_inference_steps: int = 50, sigmas: List[float] = None, true_cfg_scale: float = 1.0, - guidance_scale: float = 6.0, + guidance_scale: float = 1.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -639,7 +631,7 @@ def __call__( will be used. true_cfg_scale (`float`, *optional*, defaults to 1.0): When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. - guidance_scale (`float`, defaults to `6.0`): + guidance_scale (`float`, defaults to `1.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > From ab6c463a523ca8745574a382fe7141ca5dfdf327 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 7 Mar 2025 07:27:56 +0100 Subject: [PATCH 11/11] update --- .../hunyuan_video/pipeline_hunyuan_video_image2video.py | 2 +- .../pipelines/hunyuan_video/test_hunyuan_image2video.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index 99046661ea90..5a600dda4326 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -846,8 +846,8 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) video = video[:, :, 4:, :, :] + video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents[:, :, 1:, :, :] diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py index 48a26f6855e5..c18e5c0ad8fb 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py @@ -152,14 +152,14 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) image_processor = CLIPImageProcessor( - crop_size=224, + crop_size=336, do_center_crop=True, do_normalize=True, do_resize=True, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], resample=3, - size=224, + size=336, ) components = { @@ -213,8 +213,9 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 16, 16)) - expected_video = torch.randn(9, 3, 16, 16) + # NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline + self.assertEqual(generated_video.shape, (5, 3, 16, 16)) + expected_video = torch.randn(5, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10)