diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index f31c621293fc..4bc22c0f9f6c 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXConditionPipeline + +[[autodoc]] LTXConditionPipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index ce980b304f1b..52a7791bb2fc 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, "model.diffusion_model": remove_keys_, + "decoder.timestep_scale_multiplier": remove_keys_, } @@ -268,6 +269,9 @@ def get_vae_config(version: str) -> Dict[str, Any]: "scaling_factor": 1.0, "encoder_causal": True, "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + "timestep_scale_multiplier": 1000.0, } VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) return config @@ -346,14 +350,17 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, - ) + if args.version == "0.9.5": + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) pipe = LTXPipeline( scheduler=scheduler, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cfb0bd08f818..704dee331f7e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -347,6 +347,7 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXPipeline", "Lumina2Text2ImgPipeline", @@ -857,6 +858,7 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 5967a6e44f7d..9384f8863b16 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -921,12 +921,14 @@ def __init__( timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[bool, ...] 
= (1, 1, 1, 1), + timestep_scale_multiplier: float = 1.0, ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.timestep_scale_multiplier = timestep_scale_multiplier block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -981,9 +983,7 @@ def __init__( # timestep embedding self.time_embedder = None self.scale_shift_table = None - self.timestep_scale_multiplier = None if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -992,7 +992,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) - if self.timestep_scale_multiplier is not None: + if temb is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -1105,6 +1105,9 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = False, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + timestep_scale_multiplier: float = 1.0, ) -> None: super().__init__() @@ -1135,6 +1138,7 @@ def __init__( inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + timestep_scale_multiplier=timestep_scale_multiplier, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -1142,8 +1146,16 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) - self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. 
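Note on the VAE change above: the two new compression-ratio config arguments take precedence when set (as the 0.9.5 conversion config now does), and otherwise fall back to the previously derived values. A minimal sketch of that resolution logic; the argument values in the example calls assume a stock LTX VAE layout and are illustrative, not read from any checkpoint:

```py
def resolve_compression_ratios(
    patch_size: int,
    patch_size_t: int,
    spatio_temporal_scaling: tuple,
    spatial_compression_ratio: int = None,
    temporal_compression_ratio: int = None,
):
    # Explicit config values win; otherwise derive from patch size and per-block scaling flags.
    spatial = (
        patch_size * 2 ** sum(spatio_temporal_scaling)
        if spatial_compression_ratio is None
        else spatial_compression_ratio
    )
    temporal = (
        patch_size_t * 2 ** sum(spatio_temporal_scaling)
        if temporal_compression_ratio is None
        else temporal_compression_ratio
    )
    return spatial, temporal


print(resolve_compression_ratios(4, 1, (True, True, True, False)))         # (32, 8), derived as before
print(resolve_compression_ratios(4, 1, (True, True, True, False), 32, 8))  # (32, 8), explicit override as in 0.9.5
```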
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index e7930b333ff6..3c79921d85c5 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -115,47 +115,77 @@ def __init__( self.theta = theta self._causal_rope_fix = _causal_rope_fix - def forward( + def _prepare_video_coords( self, - hidden_states: torch.Tensor, + batch_size: int, num_frames: int, height: int, width: int, - frame_rate: Optional[int] = None, - rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = hidden_states.size(0) - + rope_interpolation_scale: Tuple[torch.Tensor, float, float], + frame_rate: float, + device: torch.device, + ) -> torch.Tensor: # Always compute rope in fp32 - grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) - grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) - grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - if rope_interpolation_scale is not None: - if isinstance(rope_interpolation_scale, tuple): - # This will be deprecated in v0.34.0 - grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames - grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height - grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + if isinstance(rope_interpolation_scale, tuple): + # This will be deprecated in v0.34.0 + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + else: + if not self._causal_rope_fix: + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames else: - if not self._causal_rope_fix: - grid[:, 0:1] = ( - grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames - ) - else: - grid[:, 0:1] = ( - ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0) - * self.patch_size_t - / self.base_num_frames - ) - grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height - grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width + grid[:, 0:1] = ( + ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0) + * self.patch_size_t + / self.base_num_frames + ) + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width grid = grid.flatten(2, 4).transpose(1, 2) + return grid + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + frame_rate: Optional[int] = None, + 
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + video_coords: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + if video_coords is None: + grid = self._prepare_video_coords( + batch_size, + num_frames, + height, + width, + rope_interpolation_scale=rope_interpolation_scale, + frame_rate=frame_rate, + device=hidden_states.device, + ) + else: + grid = torch.stack( + [ + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width, + ], + dim=-1, + ) + start = 1.0 end = self.theta freqs = self.theta ** torch.linspace( @@ -387,11 +417,12 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - num_frames: int, - height: int, - width: int, - frame_rate: int, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + frame_rate: Optional[int] = None, rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, + video_coords: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -414,7 +445,9 @@ def forward( msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." deprecate("rope_interpolation_scale", "0.34.0", msg) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale) + image_rotary_emb = self.rope( + hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords + ) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e99162e7a7fe..af5ffdca2152 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -260,7 +260,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -610,7 +610,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXImageToVideoPipeline, LTXPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaText2ImgPipeline from .lumina2 import Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 20cc1c216522..199e730d9b4d 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -34,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_condition import LTXConditionPipeline from 
.pipeline_ltx_image2video import LTXImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index a37b9b5122f2..515950de507e 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -14,14 +14,14 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import PIL.Image import torch from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTXVideo from ...models.transformers import LTXVideoTransformer3DModel @@ -46,12 +46,11 @@ Examples: ```py >>> import torch - >>> from diffusers import LTXImageToVideoPipeline + >>> from diffusers import LTXConditionPipeline >>> from diffusers.utils import export_to_video, load_image - >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> image = load_image( ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" ... ) @@ -94,6 +93,26 @@ class LTXVideoCondition: strength: float = 1.0 +# from LTX-Video/ltx_video/schedulers/rf.py +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + if num_steps < 2: + return torch.tensor([1.0]) + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -285,7 +304,7 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -382,10 +401,14 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs def check_inputs( self, prompt, + conditions, + image, + video, + frame_index, + strength, height, width, callback_on_step_end_tensor_inputs=None, @@ -436,6 +459,65 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." 
) + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None and (image is None and video is None): + raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `image` and `frame_index` must be of the same length." + ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `video` and `frame_index` must be of the same length." + ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + + @staticmethod + def _prepare_video_ids( + batch_size: int, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + device: torch.device = None, + ) -> torch.Tensor: + latent_sample_coords = torch.meshgrid( + torch.arange(0, num_frames, patch_size_t, device=device), + torch.arange(0, height, patch_size, device=device), + torch.arange(0, width, patch_size, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + + return latent_coords + + @staticmethod + def _scale_video_ids( + video_ids: torch.Tensor, + scale_factor: int = 32, + scale_factor_t: int = 8, + frame_index: int = 0, + device: torch.device = None, + ) -> torch.Tensor: + scaled_latent_coords = ( + video_ids + * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] + ) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] += frame_index + + return scaled_latent_coords + @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -495,71 +577,64 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents - def _prepare_non_first_frame_conditioning( - self, - latents: torch.Tensor, - condition_latents: torch.Tensor, - condition_latent_frames_mask: torch.Tensor, - frame_index: int, - strength: float, - num_prefix_latent_frames: int = 2, - prefix_latents_mode: str = "soft", - prefix_soft_conditioning_strength: float = 0.15, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - num_latent_frames = latents.size(2) - - if num_latent_frames < num_prefix_latent_frames: - raise ValueError( - f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_latent_frames}." - ) - if frame_index % self.vae_temporal_compression_ratio != 0: - raise ValueError( - f"Frame index must be divisible by {self.vae_temporal_compression_ratio} but got {frame_index}." 
- ) + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): + """ + Trim a conditioning sequence to the allowed number of frames. - if num_latent_frames > num_prefix_latent_frames: - start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames - end_frame = start_frame + num_latent_frames - num_prefix_latent_frames - latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], condition_latents[:, :, num_prefix_latent_frames:], strength - ) - condition_latent_frames_mask[:, start_frame:end_frame] = strength - - if prefix_latents_mode == "soft": - if num_prefix_latent_frames > 1: - start_frame = frame_index // self.vae_temporal_compression_ratio + 1 - end_frame = start_frame + num_prefix_latent_frames - 1 - strength = min(prefix_soft_conditioning_strength, strength) - latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], condition_latents[:, :, 1:num_prefix_latent_frames], strength - ) - condition_latent_frames_mask[:, start_frame:end_frame] = strength - condition_latents = None - elif prefix_latents_mode == "drop": - condition_latents = None - elif prefix_latents_mode == "concat": - condition_latents = condition_latents[:, :, :num_prefix_latent_frames] - else: - raise ValueError(f"Invalid prefix_latents_mode: {prefix_latents_mode}") + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames - return latents, condition_latents, condition_latent_frames_mask + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially + when conditioned on a single frame. 
+ """ + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents def prepare_latents( self, - conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], + conditions: List[torch.Tensor], + condition_strength: List[float], + condition_frame_index: List[int], batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, width: int = 704, num_frames: int = 161, - frame_rate: int = 25, + num_prefix_latent_frames: int = 2, generator: Optional[torch.Generator] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: - if not isinstance(conditions, list): - conditions = [conditions] - num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio @@ -567,106 +642,105 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) + extra_conditioning_latents = [] - extra_conditioning_rope_interpolation_scales = [] + extra_conditioning_video_ids = [] extra_conditioning_mask = [] - extra_conditioning_num_latents = ( - 0 # Number of extra conditioning latents added (should be removed before decoding) - ) - condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype) - - for condition in conditions: - if condition.image is not None: - data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) - elif condition.video is not None: - data = self.video_processor.preprocess_video(condition.vide, height, width) - else: - raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") - - if data.size(2) % self.vae_temporal_compression_ratio != 1: - raise ValueError( - f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " - f"but got {data.size(2)} frames." 
- ) - + extra_conditioning_num_latents = 0 + for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) - condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std) + condition_latents = self._normalize_latents( + condition_latents, self.vae.latents_mean, self.vae.latents_std + ).to(device, dtype=dtype) + num_data_frames = data.size(2) num_cond_frames = condition_latents.size(2) - if condition.frame_index == 0: + if frame_index == 0: latents[:, :, :num_cond_frames] = torch.lerp( - latents[:, :, :num_cond_frames], condition_latents, condition.strength + latents[:, :, :num_cond_frames], condition_latents, strength ) - condition_latent_frames_mask[:, :num_cond_frames] = condition.strength + condition_latent_frames_mask[:, :num_cond_frames] = strength + else: if num_data_frames > 1: - ( - latents, - condition_latents, - condition_latent_frames_mask, - ) = self._prepare_non_first_frame_conditioning( - latents, - condition_latents, - condition_latent_frames_mask, - condition.frame_index, - condition.strength, - ) - - if condition_latents is not None: - noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - condition_latents = torch.lerp(noise, condition_latents, condition.strength) - c_nlf = condition_latents.shape[2] - condition_latents = self._pack_latents( - condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - conditioning_mask = torch.full( - condition_latents.shape[:2], condition.strength, device=device, dtype=dtype - ) - - rope_interpolation_scale = [ - # TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation - # scale with the grid. - (self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ] - rope_interpolation_scale = ( - torch.tensor(rope_interpolation_scale, device=device, dtype=dtype) - .view(-1, 1, 1, 1, 1) - .repeat(1, 1, c_nlf, latent_height, latent_width) - ) - extra_conditioning_num_latents += condition_latents.size(1) + if num_cond_frames < num_prefix_latent_frames: + raise ValueError( + f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." 
+ ) + + if num_cond_frames > num_prefix_latent_frames: + start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames + end_frame = start_frame + num_cond_frames - num_prefix_latent_frames + latents[:, :, start_frame:end_frame] = torch.lerp( + latents[:, :, start_frame:end_frame], + condition_latents[:, :, num_prefix_latent_frames:], + strength, + ) + condition_latent_frames_mask[:, start_frame:end_frame] = strength + condition_latents = condition_latents[:, :, :num_prefix_latent_frames] + + noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) + condition_latents = torch.lerp(noise, condition_latents, strength) + + condition_video_ids = self._prepare_video_ids( + batch_size, + condition_latents.size(2), + latent_height, + latent_width, + patch_size=self.transformer_spatial_patch_size, + patch_size_t=self.transformer_temporal_patch_size, + device=device, + ) + condition_video_ids = self._scale_video_ids( + condition_video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=frame_index, + device=device, + ) + condition_latents = self._pack_latents( + condition_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + condition_conditioning_mask = torch.full( + condition_latents.shape[:2], strength, device=device, dtype=dtype + ) - extra_conditioning_latents.append(condition_latents) - extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale) - extra_conditioning_mask.append(conditioning_mask) + extra_conditioning_latents.append(condition_latents) + extra_conditioning_video_ids.append(condition_video_ids) + extra_conditioning_mask.append(condition_conditioning_mask) + extra_conditioning_num_latents += condition_latents.size(1) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + video_ids = self._prepare_video_ids( + batch_size, + num_latent_frames, + latent_height, + latent_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, ) - rope_interpolation_scale = [ - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ] - rope_interpolation_scale = ( - torch.tensor(rope_interpolation_scale, device=device, dtype=dtype) - .view(-1, 1, 1, 1, 1) - .repeat(1, 1, num_latent_frames, latent_height, latent_width) + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) + video_ids = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device, ) - conditioning_mask = self._pack_latents( - conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) - rope_interpolation_scale = torch.cat( - [*extra_conditioning_rope_interpolation_scales, rope_interpolation_scale], dim=2 - ) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) - return latents, conditioning_mask, rope_interpolation_scale, 
extra_conditioning_num_latents + return latents, conditioning_mask, video_ids, extra_conditioning_num_latents @property def guidance_scale(self): @@ -692,7 +766,11 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], + conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None, + image: Union[PipelineImageInput, List[PipelineImageInput]] = None, + video: List[PipelineImageInput] = None, + frame_index: Union[int, List[int]] = 0, + strength: Union[float, List[float]] = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -717,14 +795,25 @@ def __call__( attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 128, + max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: - conditions (`List[LTXVideoCondition]`): - The list of frame-conditioning items for the video generation. + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`List[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + frame_index (`int` or `List[int]`, *optional*): + The frame index or frame indices at which the image or video will conditionally effect the video + generation. If not provided, one has to pass `conditions`. + strength (`float` or `List[float]`, *optional*): + The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -807,6 +896,11 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, + conditions=conditions, + image=image, + video=video, + frame_index=frame_index, + strength=strength, height=height, width=width, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -828,6 +922,31 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + + strength = [condition.strength for condition in conditions] + frame_index = [condition.frame_index for condition in conditions] + image = [condition.image for condition in conditions] + video = [condition.video for condition in conditions] + else: + if not isinstance(image, list): + image = [image] + num_conditions = 1 + elif isinstance(image, list): + num_conditions = len(image) + if not isinstance(video, list): + video = [video] + num_conditions = 1 + elif isinstance(video, list): + num_conditions = len(video) + + if not isinstance(frame_index, list): + frame_index = [frame_index] * num_conditions + if not isinstance(strength, list): + strength = [strength] * num_conditions + device = self._execution_device # 3. 
Prepare text embeddings @@ -852,45 +971,71 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + vae_dtype = self.vae.dtype + + conditioning_tensors = [] + for condition_image, condition_video, condition_frame_index, condition_strength in zip( + image, video, frame_index, strength + ): + if condition_image is not None: + condition_tensor = ( + self.video_processor.preprocess(condition_image, height, width) + .unsqueeze(2) + .to(device, dtype=vae_dtype) + ) + elif condition_video is not None: + condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) + num_frames_input = condition_tensor.size(2) + num_frames_output = self.trim_conditioning_sequence( + condition_frame_index, num_frames_input, num_frames + ) + condition_tensor = condition_tensor[:, :, :num_frames_output] + condition_tensor = condition_tensor.to(device, dtype=vae_dtype) + else: + raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") + + if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: + raise ValueError( + f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " + f"but got {condition_tensor.size(2)} frames." + ) + conditioning_tensors.append(condition_tensor) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, rope_interpolation_scale, extra_conditioning_num_latents = self.prepare_latents( - conditions, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - frame_rate, - generator, - device, - torch.float32, + latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( + conditioning_tensors, + strength, + frame_index, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + generator=generator, + device=device, + dtype=torch.float32, ) + + video_coords = video_coords.float() + video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) + init_latents = latents.clone() if self.do_classifier_free_guidance: - conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + video_coords = torch.cat([video_coords, video_coords], dim=0) # 5. 
Prepare timesteps latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - mu = calculate_shift( - video_sequence_length, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, - timesteps, - sigmas=sigmas, - mu=mu, + timesteps=timesteps, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -902,62 +1047,49 @@ def __call__( continue if image_cond_noise_scale > 0: - latents = latents - # TODO(aryan): implement this + # Add timestep-dependent noise to the hard-conditioning latents + # This helps with motion continuity, especially when conditioned on a single frame + latents = self.add_noise_to_image_conditioning_latents( + t / 1000.0, + init_latents, + latents, + image_cond_noise_scale, + conditioning_mask, + generator, + ) latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + conditioning_mask_model_input = ( + torch.cat([conditioning_mask, conditioning_mask]) + if self.do_classifier_free_guidance + else conditioning_mask + ) latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() + timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - frame_rate=frame_rate, - rope_interpolation_scale=rope_interpolation_scale, + video_coords=video_coords, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred.float() if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) - # compute the previous noisy sample x_t -> x_t-1 - noise_pred = self._unpack_latents( - noise_pred, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - - noise_pred = noise_pred[:, :, 1:] - noise_latents = latents[:, :, 1:] - pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] - - latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) + denoised_latents = 
self.scheduler.step( + -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False + )[0] + tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) + latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) if callback_on_step_end is not None: callback_kwargs = {} @@ -975,17 +1107,19 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + latents = latents[:, extra_conditioning_num_latents:] + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + if output_type == "latent": video = latents else: - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index e3bff7582cd9..cbb27e5fad63 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -377,6 +377,7 @@ def step( s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ @@ -397,6 +398,8 @@ def step( Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. 
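The `per_token_timesteps` argument documented here is what lets the condition pipeline keep hard-conditioned tokens frozen while denoising the rest: each token steps from its own sigma down to the nearest lower sigma on the schedule (implemented in the hunk below). A minimal sketch of that per-token `dt` computation, with illustrative names and a toy schedule:

```py
import torch


def per_token_dt(schedule_sigmas: torch.Tensor, per_token_sigmas: torch.Tensor) -> torch.Tensor:
    # schedule_sigmas: (num_steps + 1,), descending; per_token_sigmas: (batch, num_tokens)
    sigmas = schedule_sigmas[:, None, None]                 # (S, 1, 1)
    lower_mask = sigmas < per_token_sigmas[None] - 1e-6     # (S, batch, num_tokens)
    lower_sigmas = (lower_mask * sigmas).max(dim=0).values  # nearest schedule sigma below each token
    return (per_token_sigmas - lower_sigmas)[..., None]     # (batch, num_tokens, 1)


schedule = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])
token_sigmas = torch.tensor([[1.0, 0.6, 0.0]])  # fully noised, partially conditioned, hard-conditioned
print(per_token_dt(schedule, token_sigmas).squeeze(-1))  # tensor([[0.2500, 0.1000, 0.0000]])
```

Because `dt` is positive under this convention (unlike `sigma_next - sigma` in the single-timestep branch), the condition pipeline passes the negated prediction, `-noise_pred`, to `scheduler.step`.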
@@ -427,16 +430,26 @@ def step( # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps - prev_sample = sample + (sigma_next - sigma) * model_output + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + dt = (per_token_sigmas - lower_sigmas)[..., None] + else: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + dt = sigma_next - sigma - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) + prev_sample = sample + dt * model_output # upon completion increase step index by one self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) if not return_dict: return (prev_sample,) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5a2818c2e245..9159c30eefb1 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1187,6 +1187,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py new file mode 100644 index 000000000000..dbb9a740b433 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_condition.py @@ -0,0 +1,284 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXConditionPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0, use_conditions=False): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + if use_conditions: + conditions = LTXVideoCondition( + image=image, + ) + else: + conditions = None + + inputs = { + "conditions": conditions, + "image": None if use_conditions else image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + 
pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs2 = self.get_dummy_inputs(device, use_conditions=True) + video = pipe(**inputs).frames + generated_video = video[0] + video2 = pipe(**inputs2).frames + generated_video2 = video2[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + max_diff = np.abs(generated_video - generated_video2).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = 
pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + )
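For reference, a hedged end-to-end usage sketch of the new `LTXConditionPipeline`, assembled from the example docstring and `__call__` signature in this diff; the checkpoint id and image URL come from the docstring, while the prompt text and step count are placeholders:

```py
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the generation on a single image placed at the first frame.
image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
condition = LTXVideoCondition(image=image, frame_index=0, strength=1.0)

video = pipe(
    conditions=[condition],  # or pass image=/video=/frame_index=/strength= directly
    prompt="A young woman in winter clothes walks through a snowy field",  # placeholder prompt
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,
    height=512,
    num_frames=161,  # 8 * k + 1 frames recommended
    num_inference_steps=40,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

The same call also accepts raw `image`/`video` inputs with `frame_index` and `strength` lists in place of `conditions`, which the pipeline converts internally as shown in the `__call__` changes above.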