
Commit 3b28306

zRzRzRzRzRzR, a-r-r-o-w, and stevhliu authored
CogVideoX 1.5 (#9877)
* CogVideoX1_1PatchEmbed test
* 1360 * 768
* refactor
* make style
* update docs
* add modeling tests for cogvideox 1.5
* update
* make fix-copies
* add ofs embed(for convert)
* add ofs embed(for convert)
* more resolution for cogvideox1.5-5b-i2v
* use even number of latent frames only
* update pipeline implementations
* make style
* set patch_size_t as None by default
* #skip frames 0
* refactor
* make style
* update docs
* fix ofs_embed
* update docs
* invert_scale_latents
* update
* fix
* Update docs/source/en/api/pipelines/cogvideox.md — Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Update docs/source/en/api/pipelines/cogvideox.md — Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Update docs/source/en/api/pipelines/cogvideox.md — Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Update docs/source/en/api/pipelines/cogvideox.md — Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py
* update conversion script
* remove copied from
* fix test
* Update docs/source/en/api/pipelines/cogvideox.md
* Update docs/source/en/api/pipelines/cogvideox.md
* Update docs/source/en/api/pipelines/cogvideox.md
* Update docs/source/en/api/pipelines/cogvideox.md

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent c3c94fe commit 3b28306

10 files changed (+405 -92 lines changed)

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 23 additions & 10 deletions
@@ -29,16 +29,29 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
 
 This pipeline was contributed by [zRzRzRzRzRzR](https://github.com/zRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
 
-There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
-- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
-- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.
-
-There is one model available that can be used with the image-to-video CogVideoX pipeline:
-- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
-
-There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
-- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
-- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
+There are three official CogVideoX checkpoints for text-to-video and video-to-video.
+| checkpoints | recommended inference dtype |
+|---|---|
+| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
+| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
+| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
+
+There are two official CogVideoX checkpoints available for image-to-video.
+| checkpoints | recommended inference dtype |
+|---|---|
+| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
+| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |
+
+For the CogVideoX 1.5 series:
+- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
+- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
+- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
+
+There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
+| checkpoints | recommended inference dtype |
+|---|---|
+| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
+| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
 
 ## Inference
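For reference (not part of this commit), the checkpoints and resolution notes added above map to a short text-to-video call. This is a minimal sketch assuming the `CogVideoXPipeline` API in diffusers and the `THUDM/CogVideoX1.5-5b` checkpoint listed in the table; the prompt and output path are illustrative.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the CogVideoX 1.5 T2V checkpoint in torch.bfloat16, its recommended inference dtype.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# CogVideoX 1.5 T2V was trained at 1360x768 and supports 81- or 161-frame generation.
video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    height=768,
    width=1360,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

# Exporting at 16 FPS is recommended for the 1.5 series.
export_to_video(video, "output.mp4", fps=16)
```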

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 66 additions & 10 deletions
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
@@ -140,6 +142,7 @@ def convert_transformer(
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -149,7 +152,9 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
@@ -163,13 +168,18 @@ def convert_transformer(
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
+
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
 
 
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae
 
 
+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 300,
+            "sample_width": 300,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -202,6 +240,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -214,7 +258,18 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
     return parser.parse_args()
 
 
@@ -230,21 +285,27 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
 
     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
             args.num_attention_heads,
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
 
     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
 
+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
@@ -276,11 +337,6 @@ def get_args():
         scheduler=scheduler,
     )
 
-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)
-
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 1 addition & 0 deletions
@@ -1057,6 +1057,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        invert_scale_latents: bool = False,
     ):
         super().__init__()
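The new flag is only registered as config here; the conversion script above sets it to True for 1.5 VAEs, and downstream code can read it to decide how latents are scaled. A purely illustrative sketch of how such a flag might be consumed (not the pipeline's actual code; values and shapes are placeholders):

```python
import torch

def scale_image_latents(latents: torch.Tensor, scaling_factor: float, invert_scale_latents: bool) -> torch.Tensor:
    # Illustration only: a checkpoint registered with invert_scale_latents=True is assumed
    # to expect the inverse of the usual multiply-by-scaling-factor convention.
    if invert_scale_latents:
        return latents / scaling_factor
    return latents * scaling_factor

latents = torch.randn(1, 16, 13, 60, 90)  # arbitrary example latents
# scaling_factor below is just the script's default value, used here as a placeholder.
scaled = scale_image_latents(latents, scaling_factor=1.15258426, invert_scale_latents=True)
```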

src/diffusers/models/embeddings.py

Lines changed: 61 additions & 15 deletions
@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
@@ -355,6 +356,7 @@ def __init__(
         super().__init__()
 
         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -366,9 +368,15 @@ def __init__(
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
-        )
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -407,12 +415,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         """
         text_embeds = self.text_proj(text_embeds)
 
-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.reshape(-1, channels, height, width)
-        image_embeds = self.proj(image_embeds)
-        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
-        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
-        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)
 
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
@@ -497,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
 
 
 def get_3d_rotary_pos_embed(
-    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed(
             The size of the temporal dimension.
         theta (`float`):
             Scaling factor for frequency computation.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
     if use_real is not True:
         raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
-    start, stop = crops_coords
-    grid_size_h, grid_size_w = grid_size
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 4
@@ -559,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
     t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
     h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
     w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
     cos = combine_time_height_width(t_cos, h_cos, w_cos)
     sin = combine_time_height_width(t_sin, h_sin, w_sin)
     return cos, sin