
Hunyuan I2V #10983


Merged: 11 commits, Mar 7, 2025
3 changes: 2 additions & 1 deletion docs/source/en/api/pipelines/hunyuan_video.md
@@ -49,7 +49,8 @@ The following models are available for the image-to-video pipeline:

| Model name | Description |
|:---|:---|
| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
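
For the `HunyuanVideo-I2V` checkpoint, the higher scheduler `shift` can be applied when the pipeline is loaded. A minimal sketch, assuming a CUDA device; the prompt, input image, and `shift=17.0` are illustrative values within the recommended 7 to 20 range, not settings taken from this PR:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.bfloat16
)
# Rebuild the scheduler with a higher flow shift, per the table above.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=17.0)
pipe.vae.enable_tiling()  # optional: lowers VAE memory use for longer videos
pipe.to("cuda")

image = load_image("input.png")  # illustrative conditioning frame
video = pipe(image=image, prompt="A cat walks on the grass", num_frames=61).frames[0]
export_to_video(video, "output.mp4", fps=15)
```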

## Quantization

115 changes: 95 additions & 20 deletions scripts/convert_hunyuan_video_to_diffusers.py
@@ -3,11 +3,19 @@

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
from transformers import (
AutoModel,
AutoTokenizer,
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
LlavaForConditionalGeneration,
)

from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
@@ -134,6 +142,46 @@ def remap_single_transformer_blocks_(key, state_dict):
VAE_SPECIAL_KEYS_REMAP = {}


TRANSFORMER_CONFIGS = {
"HYVideo-T/2-cfgdistill": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
"HYVideo-T/2-I2V": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": False,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)

@@ -149,11 +197,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict


def convert_transformer(ckpt_path: str):
def convert_transformer(ckpt_path: str, transformer_type: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
config = TRANSFORMER_CONFIGS[transformer_type]

with init_empty_weights():
transformer = HunyuanVideoTransformer3DModel()
transformer = HunyuanVideoTransformer3DModel(**config)

for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -205,6 +254,10 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
parser.add_argument(
"--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
)
parser.add_argument("--flow_shift", type=float, default=7.0)
return parser.parse_args()


@@ -228,7 +281,7 @@ def get_args():
assert args.text_encoder_2_path is not None

if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
@@ -239,19 +292,41 @@ def get_args():
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.transformer_type == "HYVideo-T/2-cfgdistill":
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)

pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
else:
text_encoder = LlavaForConditionalGeneration.from_pretrained(
args.text_encoder_path, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)

pipe = HunyuanVideoImageToVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
image_processor=image_processor,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
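
When `--save_pipeline` is set, the converted folder can be loaded back like any pipeline repository. A minimal sketch; `./converted-hunyuan-i2v` stands in for whatever `--output_path` was used:

```python
import torch
from diffusers import HunyuanVideoImageToVideoPipeline

# Hypothetical --output_path from the conversion command above.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "./converted-hunyuan-i2v", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```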
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -313,6 +313,7 @@
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -823,6 +824,7 @@
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
12 changes: 10 additions & 2 deletions src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -581,7 +581,11 @@ def __init__(
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)

if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)

# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
@@ -708,7 +712,11 @@ def forward(
image_rotary_emb = self.rope(hidden_states)

# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)

hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)

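The effect of the new `guidance_embeds` flag can be checked by instantiating both variants on the meta device, the same trick the conversion script uses. A minimal sketch; `in_channels=33` is the `16 * 2 + 1` from the I2V config above:

```python
from accelerate import init_empty_weights
from diffusers import HunyuanVideoTransformer3DModel

with init_empty_weights():
    # cfg-distilled T2V variant: guidance embedding enabled (the default)
    t2v = HunyuanVideoTransformer3DModel(guidance_embeds=True)
    # I2V variant: extra conditioning channels, no guidance embedding
    i2v = HunyuanVideoTransformer3DModel(in_channels=33, guidance_embeds=False)

print(type(t2v.time_text_embed).__name__)  # CombinedTimestepGuidanceTextProjEmbeddings
print(type(i2v.time_text_embed).__name__)  # CombinedTimestepTextProjEmbeddings
```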
12 changes: 10 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -222,7 +222,11 @@
"EasyAnimateControlPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -570,7 +574,11 @@
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hunyuan_video import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/hunyuan_video/__init__.py
@@ -24,6 +24,7 @@
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -35,6 +36,7 @@
else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline

else:
import sys
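
With these registrations in place, the new pipeline imports from the package root alongside the existing Hunyuan pipelines:

```python
from diffusers import (
    HunyuanSkyreelsImageToVideoPipeline,
    HunyuanVideoImageToVideoPipeline,
    HunyuanVideoPipeline,
)
```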