
Commit f35b807

update
1 parent ea436c4 commit f35b807

5 files changed (+1058, -12 lines)

scripts/convert_ltx_to_diffusers.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -134,12 +134,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 def convert_transformer(
     ckpt_path: str,
     dtype: torch.dtype,
+    version: str = "0.9.0",
 ):
     PREFIX_KEY = "model.diffusion_model."

     original_state_dict = get_state_dict(load_file(ckpt_path))
+    config = {}
+    if version == "0.9.5":
+        config["_use_causal_rope_fix"] = True
     with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)

     for key in list(original_state_dict.keys()):
         new_key = key[:]
```
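For reference, a minimal sketch of how the new `version` argument could be exercised when converting a 0.9.5 checkpoint. The checkpoint path is a placeholder, and the sketch assumes `convert_transformer` is called directly rather than through the script's CLI:

```python
import torch

# Hypothetical direct call into the updated helper; the path is a placeholder.
transformer = convert_transformer(
    ckpt_path="./ltx-video-0.9.5.safetensors",
    dtype=torch.bfloat16,
    version="0.9.5",  # any other value leaves the extra config flag unset
)
```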

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 32 additions & 7 deletions
```diff
@@ -14,15 +14,15 @@
 # limitations under the License.

 import math
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -102,6 +102,7 @@ def __init__(
         patch_size: int = 1,
         patch_size_t: int = 1,
         theta: float = 10000.0,
+        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()

@@ -112,13 +113,15 @@ def __init__(
         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.theta = theta
+        self._causal_rope_fix = _causal_rope_fix

     def forward(
         self,
         hidden_states: torch.Tensor,
         num_frames: int,
         height: int,
         width: int,
+        frame_rate: Optional[int] = None,
         rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = hidden_states.size(0)
@@ -132,9 +135,24 @@ def forward(
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

         if rope_interpolation_scale is not None:
-            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
-            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
-            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            if isinstance(rope_interpolation_scale, tuple):
+                # This will be deprecated in v0.34.0
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            else:
+                if not self._causal_rope_fix:
+                    grid[:, 0:1] = (
+                        grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
+                    )
+                else:
+                    grid[:, 0:1] = (
+                        ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
+                        * self.patch_size_t
+                        / self.base_num_frames
+                    )
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

         grid = grid.flatten(2, 4).transpose(1, 2)

```
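To see what the new `else` branch changes, a small standalone sketch of the two temporal scalings on the first few latent frame indices (example values only; the `patch_size_t / base_num_frames` normalisation from the code above is omitted):

```python
import torch

t = torch.arange(5, dtype=torch.float32)  # latent frame indices 0..4
scale = 8 / 24        # e.g. vae_temporal_compression_ratio / frame_rate
frame_rate = 24

old = t * scale                                        # previous scaling
new = ((t - 1) * scale + 1 / frame_rate).clamp(min=0)  # causal RoPE fix

print(old)  # tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.3333])
print(new)  # tensor([0.0000, 0.0417, 0.3750, 0.7083, 1.0417])
```

With the fix, the first latent frame is clamped to position 0 and the second starts at 1 / frame_rate rather than a full latent-frame step, so temporal positions are shifted back by one step relative to the old scaling.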

```diff
@@ -315,6 +333,7 @@ def __init__(
         caption_channels: int = 4096,
         attention_bias: bool = True,
         attention_out_bias: bool = True,
+        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()

@@ -336,6 +355,7 @@ def __init__(
             patch_size=patch_size,
             patch_size_t=patch_size_t,
             theta=10000.0,
+            _causal_rope_fix=_causal_rope_fix,
         )

         self.transformer_blocks = nn.ModuleList(
@@ -370,7 +390,8 @@ def forward(
         num_frames: int,
         height: int,
         width: int,
-        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        frame_rate: int,
+        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
@@ -389,7 +410,11 @@ def forward(
                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+        if not isinstance(rope_interpolation_scale, torch.Tensor):
+            msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
+            deprecate("rope_interpolation_scale", "0.34.0", msg)
+
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
```
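Callers that want to avoid the deprecation path can pass `rope_interpolation_scale` as a tensor rather than a tuple. Below is a rough sketch of one form that broadcasts against the `[0:1]` / `[1:2]` / `[2:3]` indexing in the RoPE forward above; the exact tensor shape the pipelines will eventually pass is an assumption, not something shown in this commit:

```python
import torch

# Example compression ratios / frame rate; values are placeholders.
vae_temporal_compression_ratio = 8
vae_spatial_compression_ratio = 32
frame_rate = 24

# A 1-D, 3-element tensor: slicing it with [0:1] etc. yields length-1 tensors
# that broadcast against the (batch, 1, frames, height, width) grid slices.
rope_interpolation_scale = torch.tensor(
    [
        vae_temporal_compression_ratio / frame_rate,  # temporal scale
        vae_spatial_compression_ratio,                # height scale
        vae_spatial_compression_ratio,                # width scale
    ]
)
```

Because the value is a `torch.Tensor`, the `deprecate` call above is skipped and the RoPE layer takes its new branch, in which `frame_rate` only affects the result when `_causal_rope_fix` is enabled.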

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -694,9 +694,8 @@ def __call__(
         self._num_timesteps = len(timesteps)

         # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
         rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
             self.vae_spatial_compression_ratio,
             self.vae_spatial_compression_ratio,
         )
```
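The pipeline change is purely algebraic: `1 / (frame_rate / r)` equals `r / frame_rate`, so the intermediate `latent_frame_rate` variable can be dropped. A quick check with example values:

```python
frame_rate = 25                      # example value
vae_temporal_compression_ratio = 8   # example value

old = 1 / (frame_rate / vae_temporal_compression_ratio)  # via latent_frame_rate
new = vae_temporal_compression_ratio / frame_rate         # refactored form
print(old, new)  # 0.32 0.32
```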
