
Commit e98fea2

yiyixuxu, hlky, and a-r-r-o-w authored
[WIP] test prepare_latents for ltx0.95 (#10976)
* up
* Update src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

  Co-authored-by: hlky <hlky@hlky.ac>

* up
* make it work
* up
* update conversion script
* up
* up
* up
* up
* up more
* up
* Apply suggestions from code review

  Co-authored-by: Aryan <aryan@huggingface.co>

* add docs tests + more refactor
* up

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 14a2282 commit e98fea2

File tree: 11 files changed, +771 −263 lines

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
   - all
   - __call__
 
+## LTXConditionPipeline
+
+[[autodoc]] LTXConditionPipeline
+  - all
+  - __call__
+
 ## LTXPipelineOutput
 
 [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
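
For orientation, a minimal usage sketch of the new pipeline might look like the following. The `Lightricks/LTX-Video-0.9.5` checkpoint id, the `LTXVideoCondition` helper, and the exact call arguments are assumptions patterned after the existing LTX pipelines, not details confirmed by this commit.

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition  # assumed helper class
from diffusers.utils import export_to_video, load_image

# Checkpoint id and dtype are placeholders; use whatever the 0.9.5 release actually ships.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the generated video on a single image placed at the first frame.
image = load_image("https://example.com/first_frame.png")  # placeholder URL
condition = LTXVideoCondition(image=image, frame_index=0)

video = pipe(
    conditions=[condition],
    prompt="A ship sailing through a stormy sea at dusk",
    negative_prompt="worst quality, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=161,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship_conditioned.mp4", fps=24)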

scripts/convert_ltx_to_diffusers.py

Lines changed: 15 additions & 8 deletions
@@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
     "model.diffusion_model": remove_keys_,
+    "decoder.timestep_scale_multiplier": remove_keys_,
 }
 
 
@@ -268,6 +269,9 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "scaling_factor": 1.0,
             "encoder_causal": True,
             "decoder_causal": False,
+            "spatial_compression_ratio": 32,
+            "temporal_compression_ratio": 8,
+            "timestep_scale_multiplier": 1000.0,
         }
         VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
     return config
@@ -346,14 +350,17 @@ def get_args():
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
 
     pipe = LTXPipeline(
         scheduler=scheduler,
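
As a standalone illustration of the branch above, the two scheduler configurations the script now chooses between can be built directly; this is only a sketch using the arguments that appear in the diff.

from diffusers import FlowMatchEulerDiscreteScheduler

# 0.9.5 checkpoints: plain flow-match scheduler, no dynamic shifting.
scheduler_095 = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)

# Earlier LTX checkpoints keep the dynamic-shifting configuration.
scheduler_legacy = FlowMatchEulerDiscreteScheduler(
    use_dynamic_shifting=True,
    base_shift=0.95,
    max_shift=2.05,
    base_image_seq_len=1024,
    max_image_seq_len=4096,
    shift_terminal=0.1,
)

print(scheduler_095.config.use_dynamic_shifting)  # False
print(scheduler_legacy.config.shift_terminal)  # 0.1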

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -347,6 +347,7 @@
             "LDMTextToImagePipeline",
             "LEditsPPPipelineStableDiffusion",
             "LEditsPPPipelineStableDiffusionXL",
+            "LTXConditionPipeline",
             "LTXImageToVideoPipeline",
             "LTXPipeline",
             "Lumina2Text2ImgPipeline",
@@ -857,6 +858,7 @@
             LDMTextToImagePipeline,
             LEditsPPPipelineStableDiffusion,
             LEditsPPPipelineStableDiffusionXL,
+            LTXConditionPipeline,
             LTXImageToVideoPipeline,
             LTXPipeline,
             Lumina2Text2ImgPipeline,

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 17 additions & 5 deletions
@@ -921,12 +921,14 @@ def __init__(
         timestep_conditioning: bool = False,
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
+        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()
 
         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.out_channels = out_channels * patch_size**2
+        self.timestep_scale_multiplier = timestep_scale_multiplier
 
         block_out_channels = tuple(reversed(block_out_channels))
         spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
@@ -981,9 +983,7 @@ def __init__(
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
-        self.timestep_scale_multiplier = None
         if timestep_conditioning:
-            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
 
@@ -992,7 +992,7 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)
 
-        if self.timestep_scale_multiplier is not None:
+        if temb is not None:
             temb = temb * self.timestep_scale_multiplier
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -1105,6 +1105,9 @@ def __init__(
         scaling_factor: float = 1.0,
         encoder_causal: bool = True,
         decoder_causal: bool = False,
+        spatial_compression_ratio: int = None,
+        temporal_compression_ratio: int = None,
+        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()
 
@@ -1135,15 +1138,24 @@ def __init__(
             inject_noise=decoder_inject_noise,
             upsample_residual=upsample_residual,
             upsample_factor=upsample_factor,
+            timestep_scale_multiplier=timestep_scale_multiplier,
         )
 
         latents_mean = torch.zeros((latent_channels,), requires_grad=False)
         latents_std = torch.ones((latent_channels,), requires_grad=False)
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
 
-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
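
The compression ratios can now be set explicitly in the VAE config (as the conversion script does for 0.9.5) and fall back to the old derivation from patch size and downsampling when omitted. A self-contained sketch of that fallback logic, with made-up example values:

from typing import Optional, Tuple

def resolve_ratio(patch_size: int, spatio_temporal_scaling: Tuple[bool, ...], override: Optional[int] = None) -> int:
    # Explicit config value wins; otherwise derive the ratio as before.
    return patch_size * 2 ** sum(spatio_temporal_scaling) if override is None else override

# Example values only; not taken from a real checkpoint config.
print(resolve_ratio(4, (True, True, True, False)))      # derived: 4 * 2**3 = 32
print(resolve_ratio(4, (True, True, True, False), 32))  # explicit spatial override from the config
print(resolve_ratio(1, (True, True, True, False), 8))   # explicit temporal-style override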

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 66 additions & 33 deletions
@@ -115,47 +115,77 @@ def __init__(
         self.theta = theta
         self._causal_rope_fix = _causal_rope_fix
 
-    def forward(
+    def _prepare_video_coords(
         self,
-        hidden_states: torch.Tensor,
+        batch_size: int,
         num_frames: int,
         height: int,
         width: int,
-        frame_rate: Optional[int] = None,
-        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size = hidden_states.size(0)
-
+        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+        frame_rate: float,
+        device: torch.device,
+    ) -> torch.Tensor:
         # Always compute rope in fp32
-        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
         grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
 
-        if rope_interpolation_scale is not None:
-            if isinstance(rope_interpolation_scale, tuple):
-                # This will be deprecated in v0.34.0
-                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
-                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
-                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+        if isinstance(rope_interpolation_scale, tuple):
+            # This will be deprecated in v0.34.0
+            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+        else:
+            if not self._causal_rope_fix:
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
             else:
-                if not self._causal_rope_fix:
-                    grid[:, 0:1] = (
-                        grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
-                    )
-                else:
-                    grid[:, 0:1] = (
-                        ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
-                        * self.patch_size_t
-                        / self.base_num_frames
-                    )
-                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
-                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
+                grid[:, 0:1] = (
+                    ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
+                    * self.patch_size_t
+                    / self.base_num_frames
+                )
+            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
+            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
 
         grid = grid.flatten(2, 4).transpose(1, 2)
 
+        return grid
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        frame_rate: Optional[int] = None,
+        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
+        video_coords: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.size(0)
+
+        if video_coords is None:
+            grid = self._prepare_video_coords(
+                batch_size,
+                num_frames,
+                height,
+                width,
+                rope_interpolation_scale=rope_interpolation_scale,
+                frame_rate=frame_rate,
+                device=hidden_states.device,
+            )
+        else:
+            grid = torch.stack(
+                [
+                    video_coords[:, 0] / self.base_num_frames,
+                    video_coords[:, 1] / self.base_height,
+                    video_coords[:, 2] / self.base_width,
+                ],
+                dim=-1,
+            )
+
         start = 1.0
         end = self.theta
         freqs = self.theta ** torch.linspace(
@@ -387,11 +417,12 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        frame_rate: int,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        frame_rate: Optional[int] = None,
         rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
+        video_coords: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
@@ -414,7 +445,9 @@ def forward(
             msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
            deprecate("rope_interpolation_scale", "0.34.0", msg)
 
-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)
+        image_rotary_emb = self.rope(
+            hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
+        )
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
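
The rope can now take an explicit `video_coords` tensor instead of deriving a regular grid from `num_frames`, `height`, and `width`. A minimal shape sketch of that branch follows; the coordinate values and base sizes below are assumptions, not the model's actual defaults.

import torch

# Per-token (frame, row, column) latent coordinates, shape (batch, 3, num_tokens).
video_coords = torch.tensor(
    [
        [
            [0.0, 0.0, 0.0, 8.0, 8.0, 8.0],    # frame index of each latent token
            [0.0, 0.0, 32.0, 0.0, 0.0, 32.0],  # row (height) position of each token
            [0.0, 32.0, 0.0, 0.0, 32.0, 0.0],  # column (width) position of each token
        ]
    ]
)

# Assumed base sizes standing in for the rope's configured values.
base_num_frames, base_height, base_width = 20, 2048, 2048

# Same normalization and stacking the new branch performs before computing frequencies.
grid = torch.stack(
    [
        video_coords[:, 0] / base_num_frames,
        video_coords[:, 1] / base_height,
        video_coords[:, 2] / base_width,
    ],
    dim=-1,
)
print(grid.shape)  # torch.Size([1, 6, 3]) -- one normalized (t, y, x) triple per token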

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -260,7 +260,7 @@
         ]
     )
     _import_structure["latte"] = ["LattePipeline"]
-    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
     _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
     _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
     _import_structure["marigold"].extend(
@@ -610,7 +610,7 @@
             LEditsPPPipelineStableDiffusion,
             LEditsPPPipelineStableDiffusionXL,
         )
-        from .ltx import LTXImageToVideoPipeline, LTXPipeline
+        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
         from .lumina import LuminaText2ImgPipeline
         from .lumina2 import Lumina2Text2ImgPipeline
         from .marigold import (

src/diffusers/pipelines/ltx/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_ltx"] = ["LTXPipeline"]
+    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
     _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_ltx import LTXPipeline
+        from .pipeline_ltx_condition import LTXConditionPipeline
         from .pipeline_ltx_image2video import LTXImageToVideoPipeline
 
 else:
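
With the pipeline registered in the top-level, pipelines, and ltx `__init__` modules above, it resolves through the lazy-import machinery at every level. A quick sanity check, assuming a diffusers build that includes this commit:

from diffusers import LTXConditionPipeline
from diffusers.pipelines import LTXConditionPipeline as FromPipelines
from diffusers.pipelines.ltx import LTXConditionPipeline as FromLTXModule

# All three import paths should resolve to the same class object.
assert LTXConditionPipeline is FromPipelines is FromLTXModule
print(LTXConditionPipeline.__name__)  # LTXConditionPipeline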
