diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index f31c621293fc..4bc22c0f9f6c 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXConditionPipeline + +[[autodoc]] LTXConditionPipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 7df0745fd98c..2e966d5d110b 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -74,6 +74,32 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "last_scale_shift_table": "scale_shift_table", } +VAE_095_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", +} + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, @@ -81,10 +107,6 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "model.diffusion_model": remove_keys_, } -VAE_091_SPECIAL_KEYS_REMAP = { - "timestep_scale_multiplier": remove_keys_, -} - def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict @@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: def convert_transformer( ckpt_path: str, dtype: torch.dtype, + version: str = "0.9.0", ): PREFIX_KEY = "model.diffusion_model." 
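The hunk above adds `VAE_095_RENAME_DICT`, but the loop that applies it sits outside this excerpt. As a rough, illustrative sketch (not the converter's actual code), such rename tables are typically applied by substring replacement over every checkpoint key:

```python
# Illustrative only: how a rename table like VAE_095_RENAME_DICT maps original
# checkpoint keys onto the diffusers layout via substring replacement.
rename_dict = {
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "last_scale_shift_table": "scale_shift_table",
}

def rename_key(key: str) -> str:
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key

print(rename_key("decoder.up_blocks.1.conv.conv.weight"))
# decoder.up_blocks.0.upsamplers.0.conv.conv.weight
```

Note that insertion order matters: `up_blocks.0` is remapped to `mid_block` before `up_blocks.1` is renamed to a path that itself contains `up_blocks.0`.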
original_state_dict = get_state_dict(load_file(ckpt_path)) + config = {} + if version == "0.9.5": + config["_use_causal_rope_fix"] = True with init_empty_weights(): - transformer = LTXVideoTransformer3DModel() + transformer = LTXVideoTransformer3DModel(**config) for key in list(original_state_dict.keys()): new_key = key[:] @@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), + "down_block_types": ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), "decoder_block_out_channels": (128, 256, 512, 512), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (4, 3, 3, 3, 4), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True, False), "decoder_inject_noise": (False, False, False, False, False), + "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (False, False, False, False), "upsample_factor": (1, 1, 1, 1), "patch_size": 4, @@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), + "down_block_types": ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), "decoder_block_out_channels": (256, 512, 1024), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (5, 6, 7, 8), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True), "decoder_inject_noise": (True, True, True, False), + "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": True, @@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, } VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) + elif version == "0.9.5": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 1024, 2048), + "down_block_types": ( + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + } + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) return config @@ -223,7 +294,7 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") parser.add_argument( - "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model" + "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version 
of the LTX model" ) return parser.parse_args() @@ -277,14 +348,17 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, - ) + if args.version == "0.9.5": + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) pipe = LTXPipeline( scheduler=scheduler, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 913816ec9a93..7754d538cfef 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -401,6 +401,7 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXPipeline", "Lumina2Pipeline", @@ -945,6 +946,7 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline, Lumina2Pipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 75709ca10dfe..2b2f77a5509d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -196,6 +196,55 @@ def forward( return hidden_states +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + class LTXVideoUpsampler3d(nn.Module): def __init__( self, @@ -204,6 +253,7 @@ def __init__( is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, + padding_mode: str = "zeros", ) -> None: super().__init__() @@ -219,6 +269,7 @@ def __init__( kernel_size=3, stride=1, is_causal=is_causal, + padding_mode=padding_mode, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -352,6 +403,118 @@ def forward( return 
hidden_states +class LTXVideo095DownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d class LTXVideoMidBlock3d(nn.Module): r""" @@ -593,8 +756,15 @@ def __init__( in_channels: int = 3, out_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + down_block_types: Tuple[str, ...] 
= ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, @@ -617,20 +787,37 @@ def __init__( ) # down blocks - num_block_out_channels = len(block_out_channels) + is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D" + num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0) self.down_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): input_channel = output_channel - output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - - down_block = LTXVideoDownBlock3D( - in_channels=input_channel, - out_channels=output_channel, - num_layers=layers_per_block[i], - resnet_eps=resnet_norm_eps, - spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, - ) + if not is_ltx_095: + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + else: + output_channel = block_out_channels[i + 1] + + if down_block_types[i] == "LTXVideoDownBlock3D": + down_block = LTXVideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + elif down_block_types[i] == "LTXVideo095DownBlock3D": + down_block = LTXVideo095DownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") self.down_blocks.append(down_block) @@ -794,7 +981,9 @@ def __init__( # timestep embedding self.time_embedder = None self.scale_shift_table = None + self.timestep_scale_multiplier = None if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -803,6 +992,9 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) @@ -891,12 +1083,19 @@ def __init__( out_channels: int = 3, latent_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + down_block_types: Tuple[str, ...] = ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_inject_noise: Tuple[bool, ...] 
= (False, False, False, False, False), + downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), timestep_conditioning: bool = False, @@ -906,6 +1105,8 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = False, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, ) -> None: super().__init__() @@ -913,8 +1114,10 @@ def __init__( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, + down_block_types=down_block_types, spatio_temporal_scaling=spatio_temporal_scaling, layers_per_block=layers_per_block, + downsample_type=downsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, @@ -941,8 +1144,16 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) - self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index f5dc63f49562..c1f2df587927 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -14,7 +14,7 @@ # limitations under the License. 
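The new explicit `spatial_compression_ratio` / `temporal_compression_ratio` arguments above are needed because the old derivation from `patch_size` and `spatio_temporal_scaling` no longer holds once the 0.9.5 encoder mixes spatial-only and temporal-only downsamplers. A quick illustrative check using the 0.9.5 values from the converter config in this diff:

```python
# Values from the 0.9.5 VAE config added in the converter script above.
patch_size, patch_size_t = 4, 1
spatio_temporal_scaling = (True, True, True, True)
downsample_type = ("spatial", "temporal", "spatiotemporal", "spatiotemporal")

# Old fallback derivation (assumes every scaling block halves both axes):
derived_spatial = patch_size * 2 ** sum(spatio_temporal_scaling)     # 64
derived_temporal = patch_size_t * 2 ** sum(spatio_temporal_scaling)  # 16

# Actual factors when counting each downsampler's per-axis stride:
spatial_strides = {"spatial": 2, "temporal": 1, "spatiotemporal": 2}
temporal_strides = {"spatial": 1, "temporal": 2, "spatiotemporal": 2}
actual_spatial, actual_temporal = patch_size, patch_size_t
for kind in downsample_type:
    actual_spatial *= spatial_strides[kind]
    actual_temporal *= temporal_strides[kind]

print(actual_spatial, actual_temporal)  # 32 8 -- matching the explicit overrides
```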
import math -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -113,20 +113,19 @@ def __init__( self.patch_size_t = patch_size_t self.theta = theta - def forward( + def _prepare_video_coords( self, - hidden_states: torch.Tensor, + batch_size: int, num_frames: int, height: int, width: int, - rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = hidden_states.size(0) - + rope_interpolation_scale: Tuple[torch.Tensor, float, float], + device: torch.device, + ) -> torch.Tensor: # Always compute rope in fp32 - grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) - grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) - grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) @@ -138,6 +137,38 @@ def forward( grid = grid.flatten(2, 4).transpose(1, 2) + return grid + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + video_coords: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + if video_coords is None: + grid = self._prepare_video_coords( + batch_size, + num_frames, + height, + width, + rope_interpolation_scale=rope_interpolation_scale, + device=hidden_states.device, + ) + else: + grid = torch.stack( + [ + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width, + ], + dim=-1, + ) + start = 1.0 end = self.theta freqs = self.theta ** torch.linspace( @@ -367,10 +398,11 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - num_frames: int, - height: int, - width: int, - rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, + video_coords: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -389,7 +421,7 @@ def forward( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 541d1a743bcb..cbcefae8bcd2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -264,7 +264,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -618,7 +618,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXImageToVideoPipeline, LTXPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 20cc1c216522..199e730d9b4d 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -34,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 866be61810a9..f7b0811d1a22 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -694,9 +694,8 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - 1 / latent_frame_rate, + self.vae_temporal_compression_ratio / frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py new file mode 100644 index 000000000000..e7f3666cb2c7 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -0,0 +1,1174 @@ +# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
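The `pipeline_ltx.py` change above rewrites the first `rope_interpolation_scale` term without changing its value; a trivial sanity check (illustrative only):

```python
# 1 / (frame_rate / ratio) == ratio / frame_rate, so the RoPE scale is unchanged.
frame_rate = 25
vae_temporal_compression_ratio = 8

old_term = 1 / (frame_rate / vae_temporal_compression_ratio)
new_term = vae_temporal_compression_ratio / frame_rate
assert abs(old_term - new_term) < 1e-12
print(old_term, new_term)  # 0.32 0.32
```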
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition + >>> from diffusers.utils import export_to_video, load_video, load_image + + >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Load input image and video + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" + ... ) + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" + ... ) + + >>> # Create conditioning objects + >>> condition1 = LTXVideoCondition( + ... image=image, + ... frame_index=0, + ... ) + >>> condition2 = LTXVideoCondition( + ... video=video, + ... frame_index=80, + ... ) + + >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> # Generate video + >>> generator = torch.Generator("cuda").manual_seed(0) + >>> video = pipe( + ... conditions=[condition1, condition2], + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=161, + ... 
num_inference_steps=40, + ... generator=generator, + ... ).frames[0] + + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +@dataclass +class LTXVideoCondition: + """ + Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. + + Attributes: + image (`PIL.Image.Image`): + The image to condition the video on. + video (`List[PIL.Image.Image]`): + The video to condition the video on. + frame_index (`int`): + The frame index at which the image or video will conditionally effect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. + """ + + image: Optional[PIL.Image.Image] = None + video: Optional[List[PIL.Image.Image]] = None + frame_index: int = 0 + strength: float = 1.0 + + +# from LTX-Video/ltx_video/schedulers/rf.py +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + if num_steps < 2: + return torch.tensor([1.0]) + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
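The `linear_quadratic_schedule` helper defined above (ported from LTX-Video's reference sampler) produces a decreasing sigma schedule; a minimal usage sketch, assuming a diffusers build that includes this new module:

```python
# Assumes the pipeline_ltx_condition module added in this PR is installed.
from diffusers.pipelines.ltx.pipeline_ltx_condition import linear_quadratic_schedule

sigmas = linear_quadratic_schedule(num_steps=40, threshold_noise=0.025)
print(sigmas.shape)      # torch.Size([40])
print(sigmas[0].item())  # 1.0 -- starts at pure noise and decreases monotonically
```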
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, 
+ prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + conditions, + image, + video, + frame_index, + strength, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None and (image is None and video is None): + raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `image` and `frame_index` must be of the same length." 
+ ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `video` and `frame_index` must be of the same length." + ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + + @staticmethod + def _prepare_video_ids( + batch_size: int, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + device: torch.device = None, + ) -> torch.Tensor: + latent_sample_coords = torch.meshgrid( + torch.arange(0, num_frames, patch_size_t, device=device), + torch.arange(0, height, patch_size, device=device), + torch.arange(0, width, patch_size, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + + return latent_coords + + @staticmethod + def _scale_video_ids( + video_ids: torch.Tensor, + scale_factor: int = 32, + scale_factor_t: int = 8, + frame_index: int = 0, + device: torch.device = None, + ) -> torch.Tensor: + scaled_latent_coords = ( + video_ids + * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] + ) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] += frame_index + + return scaled_latent_coords + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially + when conditioned on a single frame. 
+ """ + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents + + def prepare_latents( + self, + conditions: List[torch.Tensor], + condition_strength: List[float], + condition_frame_index: List[int], + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + num_prefix_latent_frames: int = 2, + generator: Optional[torch.Generator] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) + + extra_conditioning_latents = [] + extra_conditioning_video_ids = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 + for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): + condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) + condition_latents = self._normalize_latents( + condition_latents, self.vae.latents_mean, self.vae.latents_std + ).to(device, dtype=dtype) + + num_data_frames = data.size(2) + num_cond_frames = condition_latents.size(2) + + if frame_index == 0: + latents[:, :, :num_cond_frames] = torch.lerp( + latents[:, :, :num_cond_frames], condition_latents, strength + ) + condition_latent_frames_mask[:, :num_cond_frames] = strength + + else: + if num_data_frames > 1: + if num_cond_frames < num_prefix_latent_frames: + raise ValueError( + f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." 
+ ) + + if num_cond_frames > num_prefix_latent_frames: + start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames + end_frame = start_frame + num_cond_frames - num_prefix_latent_frames + latents[:, :, start_frame:end_frame] = torch.lerp( + latents[:, :, start_frame:end_frame], + condition_latents[:, :, num_prefix_latent_frames:], + strength, + ) + condition_latent_frames_mask[:, start_frame:end_frame] = strength + condition_latents = condition_latents[:, :, :num_prefix_latent_frames] + + noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) + condition_latents = torch.lerp(noise, condition_latents, strength) + + condition_video_ids = self._prepare_video_ids( + batch_size, + condition_latents.size(2), + latent_height, + latent_width, + patch_size=self.transformer_spatial_patch_size, + patch_size_t=self.transformer_temporal_patch_size, + device=device, + ) + condition_video_ids = self._scale_video_ids( + condition_video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=frame_index, + device=device, + ) + condition_latents = self._pack_latents( + condition_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + condition_conditioning_mask = torch.full( + condition_latents.shape[:2], strength, device=device, dtype=dtype + ) + + extra_conditioning_latents.append(condition_latents) + extra_conditioning_video_ids.append(condition_video_ids) + extra_conditioning_mask.append(condition_conditioning_mask) + extra_conditioning_num_latents += condition_latents.size(1) + + video_ids = self._prepare_video_ids( + batch_size, + num_latent_frames, + latent_height, + latent_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, + ) + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) + video_ids = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device, + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + if len(extra_conditioning_latents) > 0: + latents = torch.cat([*extra_conditioning_latents, latents], dim=1) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) + conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) + + return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None, + image: Union[PipelineImageInput, List[PipelineImageInput]] = None, + video: List[PipelineImageInput] = None, + frame_index: Union[int, List[int]] = 0, + strength: Union[float, List[float]] = 1.0, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + 
frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + image_cond_noise_scale: float = 0.15, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`List[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + frame_index (`int` or `List[int]`, *optional*): + The frame index or frame indices at which the image or video will conditionally effect the video + generation. If not provided, one has to pass `conditions`. + strength (`float` or `List[float]`, *optional*): + The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. 
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from the `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` will be generated
+                from the `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which the generated video is decoded.
+            decode_noise_scale (`float`, defaults to `None`):
+                The interpolation factor between random noise and denoised latents at the decode timestep.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `256`):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated frames.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+        if latents is not None:
+            raise ValueError("Passing latents is not yet supported.")
+
+        # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt=prompt, + conditions=conditions, + image=image, + video=video, + frame_index=frame_index, + strength=strength, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + + strength = [condition.strength for condition in conditions] + frame_index = [condition.frame_index for condition in conditions] + image = [condition.image for condition in conditions] + video = [condition.video for condition in conditions] + else: + if not isinstance(image, list): + image = [image] + num_conditions = 1 + elif isinstance(image, list): + num_conditions = len(image) + if not isinstance(video, list): + video = [video] + num_conditions = 1 + elif isinstance(video, list): + num_conditions = len(video) + + if not isinstance(frame_index, list): + frame_index = [frame_index] * num_conditions + if not isinstance(strength, list): + strength = [strength] * num_conditions + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + vae_dtype = self.vae.dtype + + conditioning_tensors = [] + for condition_image, condition_video, condition_frame_index, condition_strength in zip( + image, video, frame_index, strength + ): + if condition_image is not None: + condition_tensor = ( + self.video_processor.preprocess(condition_image, height, width) + .unsqueeze(2) + .to(device, dtype=vae_dtype) + ) + elif condition_video is not None: + condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) + num_frames_input = condition_tensor.size(2) + num_frames_output = self.trim_conditioning_sequence( + condition_frame_index, num_frames_input, num_frames + ) + condition_tensor = condition_tensor[:, :, :num_frames_output] + condition_tensor = condition_tensor.to(device, dtype=vae_dtype) + else: + raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") + + if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: + raise ValueError( + f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " + f"but got 
{condition_tensor.size(2)} frames." + ) + conditioning_tensors.append(condition_tensor) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( + conditioning_tensors, + strength, + frame_index, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + generator=generator, + device=device, + dtype=torch.float32, + ) + + video_coords = video_coords.float() + video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) + + init_latents = latents.clone() + + if self.do_classifier_free_guidance: + video_coords = torch.cat([video_coords, video_coords], dim=0) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps=timesteps, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_cond_noise_scale > 0: + # Add timestep-dependent noise to the hard-conditioning latents + # This helps with motion continuity, especially when conditioned on a single frame + latents = self.add_noise_to_image_conditioning_latents( + t / 1000.0, + init_latents, + latents, + image_cond_noise_scale, + conditioning_mask, + generator, + ) + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + conditioning_mask_model_input = ( + torch.cat([conditioning_mask, conditioning_mask]) + if self.do_classifier_free_guidance + else conditioning_mask + ) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() + timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + denoised_latents = self.scheduler.step( + -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False + )[0] + tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) + latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", 
prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = latents[:, extra_conditioning_num_latents:] + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + if output_type == "latent": + video = latents + else: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 0577a56ec13d..6c4214fe1b26 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -764,9 +764,8 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - 1 / latent_frame_rate, + self.vae_temporal_compression_ratio / frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index e3bff7582cd9..cbb27e5fad63 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -377,6 +377,7 @@ def step( s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ @@ -397,6 +398,8 @@ def step( Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. 
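The next hunk extends `FlowMatchEulerDiscreteScheduler.step()` so that each token can sit at its own point on the sigma schedule: every token steps from its own sigma down to the closest lower sigma in the schedule. A minimal standalone sketch of the `dt` computation introduced below, using illustrative values that are not part of the diff:

```python
import torch

# Illustrative only: mimic the per-token step performed by the new
# `per_token_timesteps` branch of step(), with assumed schedule values.
num_train_timesteps = 1000
sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])          # descending schedule
per_token_timesteps = torch.tensor([[1000.0, 600.0, 0.0]])   # shape (batch, tokens)

per_token_sigmas = per_token_timesteps / num_train_timesteps  # [[1.0, 0.6, 0.0]]

# Largest schedule sigma strictly below each token's current sigma.
lower_mask = sigmas[:, None, None] < per_token_sigmas[None] - 1e-6
lower_sigmas = (lower_mask * sigmas[:, None, None]).max(dim=0).values

# Per-token step size; the update is prev_sample = sample + dt * model_output
# (the condition pipeline passes -noise_pred as model_output).
dt = (per_token_sigmas - lower_sigmas)[..., None]

print(lower_sigmas)      # tensor([[0.7500, 0.5000, 0.0000]])
print(dt.squeeze(-1))    # tensor([[0.2500, 0.1000, 0.0000]])
```

Note that tokens whose timestep has been clamped to 0 get `dt = 0` and are left untouched, which is what the condition pipeline relies on when it computes `timestep = min(t, (1 - conditioning_mask) * 1000)` for hard-conditioned tokens.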
@@ -427,16 +430,26 @@ def step( # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps - prev_sample = sample + (sigma_next - sigma) * model_output + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + dt = (per_token_sigmas - lower_sigmas)[..., None] + else: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + dt = sigma_next - sigma - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) + prev_sample = sample + dt * model_output # upon completion increase step index by one self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) if not return_dict: return (prev_sample,) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 841ffbdafa52..fafe3081d6c3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1202,6 +1202,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py new file mode 100644 index 000000000000..dbb9a740b433 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_condition.py @@ -0,0 +1,284 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXConditionPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0, use_conditions=False): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + if use_conditions: + conditions = LTXVideoCondition( + image=image, + ) + else: + conditions = None + + inputs = { + "conditions": conditions, + "image": None if use_conditions else image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + 
pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs2 = self.get_dummy_inputs(device, use_conditions=True) + video = pipe(**inputs).frames + generated_video = video[0] + video2 = pipe(**inputs2).frames + generated_video2 = video2[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + max_diff = np.abs(generated_video - generated_video2).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = 
pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + )
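For reference, a minimal usage sketch of the new `LTXConditionPipeline` introduced in this diff. The checkpoint id, input image path, prompt, and generation settings below are illustrative assumptions rather than values taken from the diff:

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint id; substitute a repository or local path holding weights
# converted with scripts/convert_ltx_to_diffusers.py --version 0.9.5.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the first frame (frame_index=0) on a single image at full strength.
image = load_image("path/to/first_frame.png")  # placeholder input image
condition = LTXVideoCondition(image=image, frame_index=0, strength=1.0)

video = pipe(
    conditions=[condition],
    prompt="A penguin waddling across an icy shoreline at sunset, cinematic lighting",
    negative_prompt="worst quality, blurry, jittery, distorted",
    width=704,
    height=512,
    num_frames=161,
    num_inference_steps=50,
    guidance_scale=3.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]

export_to_video(video, "output.mp4", fps=24)
```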