From e202e46fe439601d7b62fa1982a682baf62c9441 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 6 Mar 2025 02:58:23 +0100 Subject: [PATCH 01/17] up --- .../models/autoencoders/autoencoder_kl_ltx.py | 6 +- .../pipelines/ltx/pipeline_ltx_condition.py | 77 +++++++++++++------ 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 5967a6e44f7d..974dc34d57e1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -1105,6 +1105,8 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = False, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, ) -> None: super().__init__() @@ -1142,8 +1144,8 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) - self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) + self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio + self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index a37b9b5122f2..197127e90919 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -437,8 +437,8 @@ def check_inputs( ) @staticmethod - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents - def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. # The patch dimensions are then permuted and collapsed into the channel dimension of shape: # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
@@ -447,6 +447,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size post_patch_width = width // patch_size + + latent_sample_coords = torch.meshgrid( + torch.arange(0, num_frames, patch_size_t, device=device), + torch.arange(0, height, patch_size, device=device), + torch.arange(0, width, patch_size, device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + latents = latents.reshape( batch_size, -1, @@ -458,7 +468,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int patch_size, ) latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents + return latents, latent_coords @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents @@ -544,6 +554,25 @@ def _prepare_non_first_frame_conditioning( return latents, condition_latents, condition_latent_frames_mask + def trim_conditioning_sequence( + self, start_frame: int, sequence_num_frames: int, target_num_frames: int + ): + """ + Trim a conditioning sequence to the allowed number of frames. + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + def prepare_latents( self, conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], @@ -579,7 +608,11 @@ def prepare_latents( if condition.image is not None: data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) elif condition.video is not None: - data = self.video_processor.preprocess_video(condition.vide, height, width) + data = self.video_processor.preprocess_video(condition.video, height, width) + num_frames_input = data.size(2) + num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames) + data = data[:, :, :num_frames_output] + data = data.to(device, dtype=dtype) else: raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") @@ -599,6 +632,7 @@ def prepare_latents( latents[:, :, :num_cond_frames], condition_latents, condition.strength ) condition_latent_frames_mask[:, :num_cond_frames] = condition.strength + # YiYi TODO: code path not tested else: if num_data_frames > 1: ( @@ -617,8 +651,8 @@ def prepare_latents( noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, condition.strength) c_nlf = condition_latents.shape[2] - condition_latents = self._pack_latents( - condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + condition_latents, condition_latent_coords = self._pack_latents( + condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) conditioning_mask = torch.full( 
condition_latents.shape[:2], condition.strength, device=device, dtype=dtype @@ -642,23 +676,22 @@ def prepare_latents( extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale) extra_conditioning_mask.append(conditioning_mask) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + latents, latent_coords = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) - rope_interpolation_scale = [ - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ] - rope_interpolation_scale = ( - torch.tensor(rope_interpolation_scale, device=device, dtype=dtype) - .view(-1, 1, 1, 1, 1) - .repeat(1, 1, num_latent_frames, latent_height, latent_width) + pixel_coords = ( + latent_coords + * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] ) - conditioning_mask = self._pack_latents( - conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) + + rope_interpolation_scale = pixel_coords + + conditioning_mask = condition_latent_frames_mask.gather( + 1, latent_coords[:, 0] ) + # YiYi TODO: code path not tested if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) rope_interpolation_scale = torch.cat( @@ -864,7 +897,7 @@ def __call__( frame_rate, generator, device, - torch.float32, + prompt_embeds.dtype, ) init_latents = latents.clone() @@ -955,8 +988,8 @@ def __call__( pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + latents, _ = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) if callback_on_step_end is not None: From 267583af252fe6aae804348e44e8f14e6fa13212 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 6 Mar 2025 08:02:02 +0100 Subject: [PATCH 02/17] up --- .../pipelines/ltx/pipeline_ltx_condition.py | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 197127e90919..a46ee26bbb32 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -513,10 +513,10 @@ def _prepare_non_first_frame_conditioning( frame_index: int, strength: float, num_prefix_latent_frames: int = 2, - prefix_latents_mode: str = "soft", + prefix_latents_mode: str = "concat", prefix_soft_conditioning_strength: float = 0.15, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - num_latent_frames = latents.size(2) + num_latent_frames = condition_latents.size(2) if num_latent_frames < num_prefix_latent_frames: raise ValueError( @@ -602,7 +602,7 @@ def prepare_latents( extra_conditioning_num_latents = ( 0 # Number of extra conditioning latents added (should be removed before decoding) ) - condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype) + condition_latent_frames_mask = 
torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) for condition in conditions: if condition.image is not None: @@ -632,7 +632,7 @@ def prepare_latents( latents[:, :, :num_cond_frames], condition_latents, condition.strength ) condition_latent_frames_mask[:, :num_cond_frames] = condition.strength - # YiYi TODO: code path not tested + else: if num_data_frames > 1: ( @@ -651,47 +651,41 @@ def prepare_latents( noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, condition.strength) c_nlf = condition_latents.shape[2] - condition_latents, condition_latent_coords = self._pack_latents( + condition_latents, rope_interpolation_scale = self._pack_latents( condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) + + rope_interpolation_scale = ( + rope_interpolation_scale * + torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] + ) + rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) + rope_interpolation_scale[:, 0] += condition.frame_index + conditioning_mask = torch.full( condition_latents.shape[:2], condition.strength, device=device, dtype=dtype ) - rope_interpolation_scale = [ - # TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation - # scale with the grid. - (self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ] - rope_interpolation_scale = ( - torch.tensor(rope_interpolation_scale, device=device, dtype=dtype) - .view(-1, 1, 1, 1, 1) - .repeat(1, 1, c_nlf, latent_height, latent_width) - ) extra_conditioning_num_latents += condition_latents.size(1) extra_conditioning_latents.append(condition_latents) extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale) extra_conditioning_mask.append(conditioning_mask) - latents, latent_coords = self._pack_latents( + latents, rope_interpolation_scale = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) - pixel_coords = ( - latent_coords + + rope_interpolation_scale = ( + rope_interpolation_scale * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] ) - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) - - rope_interpolation_scale = pixel_coords + rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) conditioning_mask = condition_latent_frames_mask.gather( 1, latent_coords[:, 0] ) - # YiYi TODO: code path not tested if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) rope_interpolation_scale = torch.cat( From 16c1467cd032459377c42a9d81a328c001ffee40 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 6 Mar 2025 08:13:40 +0100 Subject: [PATCH 03/17] up --- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 
a46ee26bbb32..fd0b70e6d3a2 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -657,7 +657,7 @@ def prepare_latents( rope_interpolation_scale = ( rope_interpolation_scale * - torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] + torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None] ) rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) rope_interpolation_scale[:, 0] += condition.frame_index @@ -675,17 +675,16 @@ def prepare_latents( latents, rope_interpolation_scale = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device ) + conditioning_mask = condition_latent_frames_mask.gather( + 1, rope_interpolation_scale[:, 0] + ) rope_interpolation_scale = ( rope_interpolation_scale - * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] + * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None] ) rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) - conditioning_mask = condition_latent_frames_mask.gather( - 1, latent_coords[:, 0] - ) - if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) rope_interpolation_scale = torch.cat( From a098d94827afd4b71f0e60503f2cd97160b5f83d Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 5 Mar 2025 21:14:56 -1000 Subject: [PATCH 04/17] Update src/diffusers/pipelines/ltx/pipeline_ltx_condition.py Co-authored-by: hlky --- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index fd0b70e6d3a2..c6682175dba5 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -452,6 +452,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int torch.arange(0, num_frames, patch_size_t, device=device), torch.arange(0, height, patch_size, device=device), torch.arange(0, width, patch_size, device=device), + indexing="ij", ) latent_sample_coords = torch.stack(latent_sample_coords, dim=0) latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) From d8bd10ed8b35066a0b9cdf0246f5ad2a1996f6b4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Mar 2025 12:30:27 +0100 Subject: [PATCH 05/17] up --- src/diffusers/models/normalization.py | 3 +- .../models/transformers/transformer_ltx.py | 96 +++-- .../pipelines/ltx/pipeline_ltx_condition.py | 402 +++++++++--------- .../scheduling_flow_match_euler_discrete.py | 25 +- 4 files changed, 280 insertions(+), 246 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 383388ca543f..db054ee117dd 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -550,7 +550,8 @@ def forward(self, 
hidden_states): hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0] if self.bias is not None: hidden_states = hidden_states + self.bias - elif is_torch_version(">=", "2.4"): + # YiYi TODO: testing only, remove this change before merging + elif is_torch_version(">=", "3.3"): if self.weight is not None: # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index e7930b333ff6..937c20a3305c 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -115,46 +115,63 @@ def __init__( self.theta = theta self._causal_rope_fix = _causal_rope_fix - def forward( - self, - hidden_states: torch.Tensor, - num_frames: int, - height: int, - width: int, - frame_rate: Optional[int] = None, - rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = hidden_states.size(0) - + + def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, width: int, rope_interpolation_scale: Tuple[torch.Tensor, float, float], device: torch.device) -> torch.Tensor: # Always compute rope in fp32 - grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) - grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) - grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - if rope_interpolation_scale is not None: - if isinstance(rope_interpolation_scale, tuple): - # This will be deprecated in v0.34.0 - grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames - grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height - grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + if isinstance(rope_interpolation_scale, tuple): + # This will be deprecated in v0.34.0 + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + else: + if not self._causal_rope_fix: + grid[:, 0:1] = ( + grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames + ) else: - if not self._causal_rope_fix: - grid[:, 0:1] = ( - grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames - ) - else: - grid[:, 0:1] = ( - ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0) - * self.patch_size_t - / self.base_num_frames - ) - grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height - grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width + grid[:, 0:1] = ( + ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0) + * self.patch_size_t + / 
self.base_num_frames + ) + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width grid = grid.flatten(2, 4).transpose(1, 2) + + return grid + + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + frame_rate: Optional[int] = None, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + video_coords: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + if video_coords is None: + grid = self._prepare_video_coords(batch_size, num_frames, height, width, rope_interpolation_scale=rope_interpolation_scale, device=hidden_states.device) + else: + grid = torch.stack( + [ + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width + ], + dim=-1, + ) start = 1.0 end = self.theta @@ -387,11 +404,12 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - num_frames: int, - height: int, - width: int, - frame_rate: int, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + frame_rate: Optional[int] = None, rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, + video_coords: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -414,7 +432,8 @@ def forward( msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." deprecate("rope_interpolation_scale", "0.34.0", msg) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale) + + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: @@ -475,5 +494,6 @@ def apply_rotary_emb(x, freqs): cos, sin = freqs x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + # YiYi TODO: testing only, remove this change before merging + out = (x * cos.to(x.dtype) + x_rotated * sin.to(x.dtype)).to(x.dtype) return out diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index fd0b70e6d3a2..d2b560145b19 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -94,6 +94,31 @@ class LTXVideoCondition: strength: float = 1.0 +# from LTX-Video/ltx_video/schedulers/rf.py +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + if num_steps < 2: + return torch.tensor([1.0]) + linear_sigma_schedule = [ + i * threshold_noise / linear_steps for i in range(linear_steps) + ] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = 
threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( + quadratic_steps**2 + ) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -285,7 +310,7 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -436,17 +461,15 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + @staticmethod # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents - def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: + def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, patch_size: int = 1, patch_size_t: int = 1, frame_index: int = 0, device: torch.device = None, return_unscaled_coords: bool = False) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. # The patch dimensions are then permuted and collapsed into the channel dimension of shape: # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features batch_size, num_channels, num_frames, height, width = latents.shape - post_patch_num_frames = num_frames // patch_size_t - post_patch_height = height // patch_size - post_patch_width = width // patch_size latent_sample_coords = torch.meshgrid( torch.arange(0, num_frames, patch_size_t, device=device), @@ -457,6 +480,30 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + scaled_latent_coords = ( + latent_coords * + torch.tensor([scale_factor_t, scale_factor, scale_factor], device=latent_coords.device)[None, :, None] + ) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] += frame_index + + if return_unscaled_coords: + return latent_coords, scaled_latent_coords + else: + return scaled_latent_coords + + @staticmethod + # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( batch_size, -1, @@ -468,7 +515,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int patch_size, ) latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents, latent_coords + return latents @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents @@ -505,54 +552,6 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents - def _prepare_non_first_frame_conditioning( - self, - latents: torch.Tensor, - condition_latents: torch.Tensor, - condition_latent_frames_mask: torch.Tensor, - frame_index: int, - strength: float, - num_prefix_latent_frames: int = 2, - prefix_latents_mode: str = "concat", - prefix_soft_conditioning_strength: float = 0.15, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - num_latent_frames = condition_latents.size(2) - - if num_latent_frames < num_prefix_latent_frames: - raise ValueError( - f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_latent_frames}." - ) - if frame_index % self.vae_temporal_compression_ratio != 0: - raise ValueError( - f"Frame index must be divisible by {self.vae_temporal_compression_ratio} but got {frame_index}." - ) - - if num_latent_frames > num_prefix_latent_frames: - start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames - end_frame = start_frame + num_latent_frames - num_prefix_latent_frames - latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], condition_latents[:, :, num_prefix_latent_frames:], strength - ) - condition_latent_frames_mask[:, start_frame:end_frame] = strength - - if prefix_latents_mode == "soft": - if num_prefix_latent_frames > 1: - start_frame = frame_index // self.vae_temporal_compression_ratio + 1 - end_frame = start_frame + num_prefix_latent_frames - 1 - strength = min(prefix_soft_conditioning_strength, strength) - latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], condition_latents[:, :, 1:num_prefix_latent_frames], strength - ) - condition_latent_frames_mask[:, start_frame:end_frame] = strength - condition_latents = None - elif prefix_latents_mode == "drop": - condition_latents = None - elif prefix_latents_mode == "concat": - condition_latents = condition_latents[:, :, :num_prefix_latent_frames] - else: - raise ValueError(f"Invalid prefix_latents_mode: {prefix_latents_mode}") - - return latents, condition_latents, condition_latent_frames_mask def trim_conditioning_sequence( self, start_frame: int, sequence_num_frames: int, target_num_frames: int @@ -573,21 +572,48 @@ def trim_conditioning_sequence( return num_frames + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. + This helps with motion continuity, especially when conditioned on a single frame. 
+ """ + generator = torch.Generator(device="cpu").manual_seed(0) + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents + def prepare_latents( self, - conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], + conditions: List[torch.Tensor], + condition_strength: List[float], + condition_frame_index: List[int], batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, width: int = 704, num_frames: int = 161, - frame_rate: int = 25, + num_prefix_latent_frames: int = 2, generator: Optional[torch.Generator] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: - if not isinstance(conditions, list): - conditions = [conditions] num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio @@ -595,104 +621,73 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) + condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) + extra_conditioning_latents = [] - extra_conditioning_rope_interpolation_scales = [] + extra_conditioning_video_ids = [] extra_conditioning_mask = [] - extra_conditioning_num_latents = ( - 0 # Number of extra conditioning latents added (should be removed before decoding) - ) - condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) - - for condition in conditions: - if condition.image is not None: - data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) - elif condition.video is not None: - data = self.video_processor.preprocess_video(condition.video, height, width) - num_frames_input = data.size(2) - num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames) - data = data[:, :, :num_frames_output] - data = data.to(device, dtype=dtype) - else: - raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") - - if data.size(2) % self.vae_temporal_compression_ratio != 1: - raise ValueError( - f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " - f"but got {data.size(2)} frames." 
- ) - + extra_conditioning_num_latents = 0 + for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) + # condition_latents = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt").to(device, dtype=dtype) condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std) + num_data_frames = data.size(2) num_cond_frames = condition_latents.size(2) - if condition.frame_index == 0: + if frame_index == 0: latents[:, :, :num_cond_frames] = torch.lerp( - latents[:, :, :num_cond_frames], condition_latents, condition.strength + latents[:, :, :num_cond_frames], condition_latents, strength ) - condition_latent_frames_mask[:, :num_cond_frames] = condition.strength + condition_latent_frames_mask[:, :num_cond_frames] = strength else: if num_data_frames > 1: - ( - latents, - condition_latents, - condition_latent_frames_mask, - ) = self._prepare_non_first_frame_conditioning( - latents, - condition_latents, - condition_latent_frames_mask, - condition.frame_index, - condition.strength, - ) - - if condition_latents is not None: - noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - condition_latents = torch.lerp(noise, condition_latents, condition.strength) - c_nlf = condition_latents.shape[2] - condition_latents, rope_interpolation_scale = self._pack_latents( - condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device - ) - - rope_interpolation_scale = ( - rope_interpolation_scale * - torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None] - ) - rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) - rope_interpolation_scale[:, 0] += condition.frame_index - - conditioning_mask = torch.full( - condition_latents.shape[:2], condition.strength, device=device, dtype=dtype - ) - - extra_conditioning_num_latents += condition_latents.size(1) - - extra_conditioning_latents.append(condition_latents) - extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale) - extra_conditioning_mask.append(conditioning_mask) - - latents, rope_interpolation_scale = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device - ) + if num_cond_frames < num_prefix_latent_frames: + raise ValueError( + f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." 
+ ) + + if num_cond_frames > num_prefix_latent_frames: + start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames + end_frame = start_frame + num_cond_frames - num_prefix_latent_frames + latents[:, :, start_frame:end_frame] = torch.lerp( + latents[:, :, start_frame:end_frame], condition_latents[:, :, num_prefix_latent_frames:], strength + ) + condition_latent_frames_mask[:, start_frame:end_frame] = strength + condition_latents = condition_latents[:, :, :num_prefix_latent_frames] + + + noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) + # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) + condition_latents = torch.lerp(noise, condition_latents, strength) + + condition_video_ids = self._prepare_video_ids(condition_latents, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, frame_index=frame_index, device=device) + condition_latents = self._pack_latents(condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) + condition_conditioning_mask = torch.full(condition_latents.shape[:2], strength, device=device, dtype=dtype) + + + extra_conditioning_latents.append(condition_latents) + extra_conditioning_video_ids.append(condition_video_ids) + extra_conditioning_mask.append(condition_conditioning_mask) + extra_conditioning_num_latents += condition_latents.size(1) + + video_ids, video_ids_scaled = self._prepare_video_ids(latents, scale_factor_t = self.vae_temporal_compression_ratio, scale_factor = self.vae_spatial_compression_ratio, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device, return_unscaled_coords=True) + latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) conditioning_mask = condition_latent_frames_mask.gather( - 1, rope_interpolation_scale[:, 0] + 1, video_ids[:, 0] ) - - rope_interpolation_scale = ( - rope_interpolation_scale - * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None] - ) - rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) + if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) - rope_interpolation_scale = torch.cat( - [*extra_conditioning_rope_interpolation_scales, rope_interpolation_scale], dim=2 - ) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids_scaled], dim=2) conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) - return latents, conditioning_mask, rope_interpolation_scale, extra_conditioning_num_latents + return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + @property def guidance_scale(self): @@ -743,7 +738,7 @@ def __call__( attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 128, + max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. 
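To make the coordinate bookkeeping in the hunks above easier to follow, here is a minimal standalone sketch (not part of the patch) of what the grid construction and pixel-space scaling in the new `_prepare_video_ids` helper and the per-token `conditioning_mask` gather in `prepare_latents` compute. The batch size, latent dimensions, compression ratios, conditioning strength, and frame index below are made-up example values, and the transformer patch sizes are assumed to be 1, matching the helper's defaults.

    # Illustrative sketch only; shapes and ratios are example values, not checkpoint config.
    import torch

    batch_size, num_latent_frames, latent_height, latent_width = 1, 3, 4, 4
    temporal_ratio, spatial_ratio = 8, 32  # example VAE compression ratios
    frame_index = 0                        # example condition frame index

    # meshgrid of latent coordinates, stacked into [B, 3, F*H*W] (patch sizes assumed 1)
    coords = torch.meshgrid(
        torch.arange(num_latent_frames),
        torch.arange(latent_height),
        torch.arange(latent_width),
        indexing="ij",
    )
    video_ids = torch.stack(coords, dim=0).unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
    video_ids = video_ids.reshape(batch_size, 3, -1)

    # scale latent coordinates to pixel coordinates; the temporal axis is shifted so the
    # first (causal) latent frame maps to pixel frame 0, then offset by the frame index
    scale = torch.tensor([temporal_ratio, spatial_ratio, spatial_ratio])[None, :, None]
    pixel_ids = video_ids * scale
    pixel_ids[:, 0] = (pixel_ids[:, 0] + 1 - temporal_ratio).clamp(min=0) + frame_index

    # per-frame conditioning strength -> per-token mask, gathered along the frame axis
    condition_latent_frames_mask = torch.zeros(batch_size, num_latent_frames)
    condition_latent_frames_mask[:, 0] = 1.0  # e.g. first latent frame fully conditioned
    conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])

    print(pixel_ids.shape, conditioning_mask.shape)  # [1, 3, 48] and [1, 48]

Note that in the pipeline code above the gather is done on the unscaled ids (so the index stays a valid latent-frame index), while the scaled, pixel-space ids are what get concatenated and handed to the transformer as `video_coords`.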
@@ -878,45 +873,65 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + vae_dtype = self.vae.dtype + + conditioning_tensors = [] + conditioning_strengths = [] + conditioning_start_frames = [] + + for condition in conditions: + if condition.image is not None: + condition_tensor = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) + elif condition.video is not None: + condition_tensor = self.video_processor.preprocess_video(condition.video, height, width) + num_frames_input = condition_tensor.size(2) + num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames) + condition_tensor = condition_tensor[:, :, :num_frames_output] + condition_tensor = condition_tensor.to(device, dtype=vae_dtype) + else: + raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") + + if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: + raise ValueError( + f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " + f"but got {condition_tensor.size(2)} frames." + ) + conditioning_tensors.append(condition_tensor) + conditioning_strengths.append(condition.strength) + conditioning_start_frames.append(condition.frame_index) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, rope_interpolation_scale, extra_conditioning_num_latents = self.prepare_latents( - conditions, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - frame_rate, - generator, - device, - prompt_embeds.dtype, + latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( + conditioning_tensors, + conditioning_strengths, + conditioning_start_frames, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + generator=generator, + device=device, + dtype=prompt_embeds.dtype, ) init_latents = latents.clone() if self.do_classifier_free_guidance: - conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + video_coords = torch.cat([video_coords, video_coords], dim=0) + # 5. 
Prepare timesteps latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - mu = calculate_shift( - video_sequence_length, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, - timesteps, - sigmas=sigmas, - mu=mu, + timesteps=timesteps, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -928,62 +943,45 @@ def __call__( continue if image_cond_noise_scale > 0: - latents = latents - # TODO(aryan): implement this + # Add timestep-dependent noise to the hard-conditioning latents + # This helps with motion continuity, especially when conditioned on a single frame + latents = self.add_noise_to_image_conditioning_latents( + t/1000.0, + init_latents, + latents.float(), + image_cond_noise_scale, + conditioning_mask, + generator, + ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + conditioning_mask_model_input = torch.cat([conditioning_mask, conditioning_mask]) if self.do_classifier_free_guidance else conditioning_mask latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - frame_rate=frame_rate, - rope_interpolation_scale=rope_interpolation_scale, + video_coords=video_coords, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred.float() - if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) - # compute the previous noisy sample x_t -> x_t-1 - noise_pred = self._unpack_latents( - noise_pred, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - - noise_pred = noise_pred[:, :, 1:] - noise_latents = latents[:, :, 1:] - pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] + denoised_latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + t_eps = 1e-6 + tokens_to_denoise_mask = (t/1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) + latents = 
torch.where(tokens_to_denoise_mask, denoised_latents, latents) - latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) - latents, _ = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device - ) if callback_on_step_end is not None: callback_kwargs = {} @@ -1001,10 +999,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if output_type == "latent": - video = latents - else: - latents = self._unpack_latents( + latents = latents[:, extra_conditioning_num_latents:] + latents = self._unpack_latents( latents, latent_num_frames, latent_height, @@ -1012,6 +1008,10 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) + + if output_type == "latent": + video = latents + else: latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index e3bff7582cd9..97081785b0cf 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -421,22 +421,35 @@ def step( ), ) + use_per_token_timesteps = isinstance(timestep, torch.Tensor) and timestep.ndim == 2 if self.step_index is None: - self._init_step_index(timestep) + if not use_per_token_timesteps: + self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] - - prev_sample = sample + (sigma_next - sigma) * model_output + if use_per_token_timesteps: + t_eps = 1e-6 + per_token_sigmas = timestep/self.config.num_train_timesteps + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - t_eps + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + dt = (per_token_sigmas - lower_sigmas)[..., None] + else: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + dt = sigma_next - sigma + + prev_sample = sample + dt * model_output # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one - self._step_index += 1 + if not use_per_token_timesteps: + self._step_index += 1 if not return_dict: return (prev_sample,) From cbc035d45186bb3814b9a246c2a526ecf8d0f563 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 01:40:09 +0100 Subject: [PATCH 06/17] make it work --- .../pipelines/ltx/pipeline_ltx_condition.py | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index b7969e5d7630..b82881bae81d 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -464,12 +464,7 @@ def check_inputs( @staticmethod # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents - def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, patch_size: int = 1, patch_size_t: int = 1, frame_index: int = 0, device: torch.device = None, return_unscaled_coords: bool = False) -> torch.Tensor: - # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W 
// p, p]. - # The patch dimensions are then permuted and collapsed into the channel dimension of shape: - # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). - # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features - batch_size, num_channels, num_frames, height, width = latents.shape + def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: latent_sample_coords = torch.meshgrid( torch.arange(0, num_frames, patch_size_t, device=device), @@ -481,17 +476,21 @@ def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_fact latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + return latent_coords + + + @staticmethod + # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, frame_index: int = 0, device: torch.device = None) -> torch.Tensor: + scaled_latent_coords = ( - latent_coords * - torch.tensor([scale_factor_t, scale_factor, scale_factor], device=latent_coords.device)[None, :, None] + video_ids * + torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] ) scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) scaled_latent_coords[:, 0] += frame_index - if return_unscaled_coords: - return latent_coords, scaled_latent_coords - else: - return scaled_latent_coords + return scaled_latent_coords @staticmethod # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents @@ -622,7 +621,7 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) + latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) @@ -632,8 +631,8 @@ def prepare_latents( extra_conditioning_num_latents = 0 for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) - # condition_latents = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt").to(device, dtype=dtype) condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std) + condition_latents = torch.load("/raid/yiyi/LTX-Video/conditioning_latents.pt").to(device, dtype=dtype) num_data_frames = data.size(2) num_cond_frames = condition_latents.size(2) @@ -662,10 +661,11 @@ def prepare_latents( noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) + noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, strength) - condition_video_ids = self._prepare_video_ids(condition_latents, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, 
patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, frame_index=frame_index, device=device) + condition_video_ids = self._prepare_video_ids(batch_size, condition_latents.size(2), latent_height, latent_width, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, device=device) + condition_video_ids = self._scale_video_ids(condition_video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=frame_index, device=device) condition_latents = self._pack_latents(condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) condition_conditioning_mask = torch.full(condition_latents.shape[:2], strength, device=device, dtype=dtype) @@ -675,7 +675,8 @@ def prepare_latents( extra_conditioning_mask.append(condition_conditioning_mask) extra_conditioning_num_latents += condition_latents.size(1) - video_ids, video_ids_scaled = self._prepare_video_ids(latents, scale_factor_t = self.vae_temporal_compression_ratio, scale_factor = self.vae_spatial_compression_ratio, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device, return_unscaled_coords=True) + video_ids = self._prepare_video_ids(batch_size, num_latent_frames, latent_height, latent_width, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device) + video_ids_scaled = self._scale_video_ids(video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=0, device=device) latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) conditioning_mask = condition_latent_frames_mask.gather( 1, video_ids[:, 0] @@ -916,6 +917,10 @@ def __call__( device=device, dtype=prompt_embeds.dtype, ) + + video_coords = video_coords.float() + video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) + init_latents = latents.clone() if self.do_classifier_free_guidance: @@ -949,7 +954,7 @@ def __call__( latents = self.add_noise_to_image_conditioning_latents( t/1000.0, init_latents, - latents.float(), + latents, image_cond_noise_scale, conditioning_mask, generator, @@ -961,7 +966,7 @@ def __call__( latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) noise_pred = self.transformer( @@ -973,12 +978,13 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) - denoised_latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0] t_eps = 1e-6 tokens_to_denoise_mask = (t/1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) From 1fdebeac5229396091a0388aa4e112d6d12fed04 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 
2025 03:34:00 +0100 Subject: [PATCH 07/17] up --- .../schedulers/scheduling_flow_match_euler_discrete.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 97081785b0cf..68cf59ef70f4 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -444,12 +444,12 @@ def step( dt = sigma_next - sigma prev_sample = sample + dt * model_output - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one if not use_per_token_timesteps: self._step_index += 1 + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) if not return_dict: return (prev_sample,) From 0cc19058eee7915feb5ce9297cac0fcac439322e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 08:34:07 +0100 Subject: [PATCH 08/17] update conversion script --- scripts/convert_ltx_to_diffusers.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index ce980b304f1b..a55e1be6513c 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -268,6 +268,8 @@ def get_vae_config(version: str) -> Dict[str, Any]: "scaling_factor": 1.0, "encoder_causal": True, "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, } VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) return config @@ -346,13 +348,16 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, + if args.version == "0.9.5": + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, ) pipe = LTXPipeline( From 7c2151fef80f736a05a9cce7af1a7aa71953178c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 08:34:37 +0100 Subject: [PATCH 09/17] up --- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index b82881bae81d..2208168d04a2 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -463,7 +463,6 @@ def check_inputs( @staticmethod - # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: latent_sample_coords = torch.meshgrid( @@ -480,7 +479,6 @@ def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int @staticmethod - # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, frame_index: int = 0, device: torch.device = None) 
-> torch.Tensor: scaled_latent_coords = ( @@ -493,7 +491,7 @@ def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_fact return scaled_latent_coords @staticmethod - # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. # The patch dimensions are then permuted and collapsed into the channel dimension of shape: From 353728adc7e111783901e3e9b1c7f3eaec48f97e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 08:58:13 +0100 Subject: [PATCH 10/17] up --- scripts/convert_ltx_to_diffusers.py | 4 +- .../models/autoencoders/autoencoder_kl_ltx.py | 12 +- .../models/transformers/transformer_ltx.py | 42 +++-- .../pipelines/ltx/pipeline_ltx_condition.py | 164 +++++++++++------- .../scheduling_flow_match_euler_discrete.py | 4 +- 5 files changed, 141 insertions(+), 85 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index a55e1be6513c..2e966d5d110b 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -350,7 +350,7 @@ def get_args(): if args.version == "0.9.5": scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) - else: + else: scheduler = FlowMatchEulerDiscreteScheduler( use_dynamic_shifting=True, base_shift=0.95, @@ -358,7 +358,7 @@ def get_args(): base_image_seq_len=1024, max_image_seq_len=4096, shift_terminal=0.1, - ) + ) pipe = LTXPipeline( scheduler=scheduler, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 974dc34d57e1..2b2f77a5509d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -1144,8 +1144,16 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio - self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. 
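The compression ratios registered above feed directly into the new coordinate path used by the transformer and pipeline changes in the following hunks: each latent token gets a (frame, row, column) id on the latent grid, which is then rescaled by the VAE's temporal/spatial compression so RoPE and the conditioning masks operate in pixel/frame space. A minimal standalone sketch of that bookkeeping (illustrative helper names, not the pipeline's exact code, assuming `patch_size = patch_size_t = 1` as in LTX):

```py
import torch


def prepare_video_ids(batch_size, num_frames, height, width, device=None):
    # One (t, y, x) coordinate per latent token, assuming patch_size = patch_size_t = 1.
    coords = torch.meshgrid(
        torch.arange(num_frames, device=device),
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    coords = torch.stack(coords, dim=0)  # [3, F, H, W]
    coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]
    return coords.reshape(batch_size, 3, -1)  # [B, 3, F * H * W]


def scale_video_ids(video_ids, scale_factor=32, scale_factor_t=8, frame_index=0):
    # Map latent-grid coordinates back to pixel-space frame/row/column indices.
    scale = torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)
    scaled = video_ids * scale[None, :, None]
    # The causal VAE maps latent frame 0 to a single pixel frame and every later latent
    # frame to a block of `scale_factor_t` frames, hence the `+ 1 - scale_factor_t`
    # correction before clamping; `frame_index` shifts the whole condition in time.
    scaled[:, 0] = (scaled[:, 0] + 1 - scale_factor_t).clamp(min=0) + frame_index
    return scaled


ids = prepare_video_ids(batch_size=1, num_frames=3, height=4, width=4)
print(scale_video_ids(ids, frame_index=16)[:, :, :4])
```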
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 937c20a3305c..79f006f7eb23 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -115,8 +115,16 @@ def __init__( self.theta = theta self._causal_rope_fix = _causal_rope_fix - - def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, width: int, rope_interpolation_scale: Tuple[torch.Tensor, float, float], device: torch.device) -> torch.Tensor: + def _prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Tuple[torch.Tensor, float, float], + frame_rate: float, + device: torch.device, + ) -> torch.Tensor: # Always compute rope in fp32 grid_h = torch.arange(height, dtype=torch.float32, device=device) grid_w = torch.arange(width, dtype=torch.float32, device=device) @@ -132,9 +140,7 @@ def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, w grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width else: if not self._causal_rope_fix: - grid[:, 0:1] = ( - grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames - ) + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames else: grid[:, 0:1] = ( ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0) @@ -145,9 +151,8 @@ def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, w grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width grid = grid.flatten(2, 4).transpose(1, 2) - + return grid - def forward( self, @@ -162,14 +167,22 @@ def forward( batch_size = hidden_states.size(0) if video_coords is None: - grid = self._prepare_video_coords(batch_size, num_frames, height, width, rope_interpolation_scale=rope_interpolation_scale, device=hidden_states.device) + grid = self._prepare_video_coords( + batch_size, + num_frames, + height, + width, + rope_interpolation_scale=rope_interpolation_scale, + frame_rate=frame_rate, + device=hidden_states.device, + ) else: grid = torch.stack( [ - video_coords[:, 0] / self.base_num_frames, - video_coords[:, 1] / self.base_height, - video_coords[:, 2] / self.base_width - ], + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width, + ], dim=-1, ) @@ -432,8 +445,9 @@ def forward( msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." 
deprecate("rope_interpolation_scale", "0.34.0", msg) - - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords) + image_rotary_emb = self.rope( + hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords + ) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 2208168d04a2..64af9010eab2 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -14,9 +14,8 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import PIL.Image import torch from transformers import T5EncoderModel, T5TokenizerFast @@ -100,19 +99,14 @@ def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=Non linear_steps = num_steps // 2 if num_steps < 2: return torch.tensor([1.0]) - linear_sigma_schedule = [ - i * threshold_noise / linear_steps for i in range(linear_steps) - ] + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] threshold_noise_step_diff = linear_steps - threshold_noise * num_steps quadratic_steps = num_steps - linear_steps quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) - linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( - quadratic_steps**2 - ) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) const = quadratic_coef * (linear_steps**2) quadratic_sigma_schedule = [ - quadratic_coef * (i**2) + linear_coef * i + const - for i in range(linear_steps, num_steps) + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) ] sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] sigma_schedule = [1.0 - x for x in sigma_schedule] @@ -461,10 +455,16 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." 
) - @staticmethod - def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: - + def _prepare_video_ids( + batch_size: int, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + device: torch.device = None, + ) -> torch.Tensor: latent_sample_coords = torch.meshgrid( torch.arange(0, num_frames, patch_size_t, device=device), torch.arange(0, height, patch_size, device=device), @@ -476,23 +476,27 @@ def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) return latent_coords - @staticmethod - def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, frame_index: int = 0, device: torch.device = None) -> torch.Tensor: - + def _scale_video_ids( + video_ids: torch.Tensor, + scale_factor: int = 32, + scale_factor_t: int = 8, + frame_index: int = 0, + device: torch.device = None, + ) -> torch.Tensor: scaled_latent_coords = ( - video_ids * - torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] + video_ids + * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] ) - scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) scaled_latent_coords[:, 0] += frame_index return scaled_latent_coords - + @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents - def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. # The patch dimensions are then permuted and collapsed into the channel dimension of shape: # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). @@ -501,7 +505,6 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size post_patch_width = width // patch_size - latents = latents.reshape( batch_size, -1, @@ -550,12 +553,10 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents - - def trim_conditioning_sequence( - self, start_frame: int, sequence_num_frames: int, target_num_frames: int - ): + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): """ Trim a conditioning sequence to the allowed number of frames. + Args: start_frame (int): The target frame number of the first frame in the sequence. sequence_num_frames (int): The number of frames in the sequence. @@ -569,7 +570,6 @@ def trim_conditioning_sequence( num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 return num_frames - @staticmethod def add_noise_to_image_conditioning_latents( t: float, @@ -581,8 +581,8 @@ def add_noise_to_image_conditioning_latents( eps=1e-6, ): """ - Add timestep-dependent noise to the hard-conditioning latents. - This helps with motion continuity, especially when conditioned on a single frame. 
+ Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially + when conditioned on a single frame. """ generator = torch.Generator(device="cpu").manual_seed(0) noise = randn_tensor( @@ -612,7 +612,6 @@ def prepare_latents( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: - num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio @@ -622,7 +621,7 @@ def prepare_latents( latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) - + extra_conditioning_latents = [] extra_conditioning_video_ids = [] extra_conditioning_mask = [] @@ -631,7 +630,7 @@ def prepare_latents( condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std) condition_latents = torch.load("/raid/yiyi/LTX-Video/conditioning_latents.pt").to(device, dtype=dtype) - + num_data_frames = data.size(2) num_cond_frames = condition_latents.size(2) @@ -652,34 +651,67 @@ def prepare_latents( start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames end_frame = start_frame + num_cond_frames - num_prefix_latent_frames latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], condition_latents[:, :, num_prefix_latent_frames:], strength + latents[:, :, start_frame:end_frame], + condition_latents[:, :, num_prefix_latent_frames:], + strength, ) condition_latent_frames_mask[:, start_frame:end_frame] = strength condition_latents = condition_latents[:, :, :num_prefix_latent_frames] - noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, strength) - condition_video_ids = self._prepare_video_ids(batch_size, condition_latents.size(2), latent_height, latent_width, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, device=device) - condition_video_ids = self._scale_video_ids(condition_video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=frame_index, device=device) - condition_latents = self._pack_latents(condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) - condition_conditioning_mask = torch.full(condition_latents.shape[:2], strength, device=device, dtype=dtype) - - + condition_video_ids = self._prepare_video_ids( + batch_size, + condition_latents.size(2), + latent_height, + latent_width, + patch_size=self.transformer_spatial_patch_size, + patch_size_t=self.transformer_temporal_patch_size, + device=device, + ) + condition_video_ids = self._scale_video_ids( + condition_video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=frame_index, + device=device, + ) + condition_latents = self._pack_latents( + condition_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + condition_conditioning_mask = torch.full( + condition_latents.shape[:2], strength, 
device=device, dtype=dtype + ) + extra_conditioning_latents.append(condition_latents) extra_conditioning_video_ids.append(condition_video_ids) extra_conditioning_mask.append(condition_conditioning_mask) extra_conditioning_num_latents += condition_latents.size(1) - - video_ids = self._prepare_video_ids(batch_size, num_latent_frames, latent_height, latent_width, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device) - video_ids_scaled = self._scale_video_ids(video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=0, device=device) - latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device) - conditioning_mask = condition_latent_frames_mask.gather( - 1, video_ids[:, 0] + + video_ids = self._prepare_video_ids( + batch_size, + num_latent_frames, + latent_height, + latent_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, ) - + video_ids_scaled = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device, + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) @@ -688,7 +720,6 @@ def prepare_latents( return latents, conditioning_mask, video_ids, extra_conditioning_num_latents - @property def guidance_scale(self): return self._guidance_scale @@ -874,7 +905,7 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) vae_dtype = self.vae.dtype - + conditioning_tensors = [] conditioning_strengths = [] conditioning_start_frames = [] @@ -885,7 +916,9 @@ def __call__( elif condition.video is not None: condition_tensor = self.video_processor.preprocess_video(condition.video, height, width) num_frames_input = condition_tensor.size(2) - num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames) + num_frames_output = self.trim_conditioning_sequence( + condition.frame_index, num_frames_input, num_frames + ) condition_tensor = condition_tensor[:, :, :num_frames_output] condition_tensor = condition_tensor.to(device, dtype=vae_dtype) else: @@ -918,13 +951,12 @@ def __call__( video_coords = video_coords.float() video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) - + init_latents = latents.clone() if self.do_classifier_free_guidance: video_coords = torch.cat([video_coords, video_coords], dim=0) - # 5. 
Prepare timesteps latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio @@ -950,7 +982,7 @@ def __call__( # Add timestep-dependent noise to the hard-conditioning latents # This helps with motion continuity, especially when conditioned on a single frame latents = self.add_noise_to_image_conditioning_latents( - t/1000.0, + t / 1000.0, init_latents, latents, image_cond_noise_scale, @@ -958,9 +990,12 @@ def __call__( generator, ) - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - conditioning_mask_model_input = torch.cat([conditioning_mask, conditioning_mask]) if self.do_classifier_free_guidance else conditioning_mask + conditioning_mask_model_input = ( + torch.cat([conditioning_mask, conditioning_mask]) + if self.do_classifier_free_guidance + else conditioning_mask + ) latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -984,10 +1019,9 @@ def __call__( denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0] t_eps = 1e-6 - tokens_to_denoise_mask = (t/1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) + tokens_to_denoise_mask = (t / 1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) - if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -1006,14 +1040,14 @@ def __call__( latents = latents[:, extra_conditioning_num_latents:] latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + if output_type == "latent": video = latents else: diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 68cf59ef70f4..b391e487188a 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -431,7 +431,7 @@ def step( if use_per_token_timesteps: t_eps = 1e-6 - per_token_sigmas = timestep/self.config.num_train_timesteps + per_token_sigmas = timestep / self.config.num_train_timesteps sigmas = self.sigmas[:, None, None] lower_mask = sigmas < per_token_sigmas[None] - t_eps @@ -442,7 +442,7 @@ def step( sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] dt = sigma_next - sigma - + prev_sample = sample + dt * model_output # upon completion increase step index by one From d85d21c87cf16d4f39fca809cc9c9fd50db2ac76 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 09:00:54 +0100 Subject: [PATCH 11/17] up --- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 64af9010eab2..2c4cc716b7e3 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -946,7 +946,7 @@ def __call__( num_frames=num_frames, generator=generator, device=device, - dtype=prompt_embeds.dtype, + dtype=torch.float32, ) 
video_coords = video_coords.float() From 445cf58e58a500bf7f5673740d00a57bb69d8f85 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 09:03:47 +0100 Subject: [PATCH 12/17] up --- src/diffusers/models/normalization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index db054ee117dd..383388ca543f 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -550,8 +550,7 @@ def forward(self, hidden_states): hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0] if self.bias is not None: hidden_states = hidden_states + self.bias - # YiYi TODO: testing only, remove this change before merging - elif is_torch_version(">=", "3.3"): + elif is_torch_version(">=", "2.4"): if self.weight is not None: # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: From fb46d217a0288e078fd2a5d106b1e6190c0815fa Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 09:08:27 +0100 Subject: [PATCH 13/17] up more --- src/diffusers/models/transformers/transformer_ltx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 79f006f7eb23..3c79921d85c5 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -508,6 +508,5 @@ def apply_rotary_emb(x, freqs): cos, sin = freqs x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - # YiYi TODO: testing only, remove this change before merging - out = (x * cos.to(x.dtype) + x_rotated * sin.to(x.dtype)).to(x.dtype) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out From 64df9afdea29310ebf7dfdfc15f20320dd4bcfc6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Mar 2025 09:16:55 +0100 Subject: [PATCH 14/17] up --- .../pipelines/ltx/pipeline_ltx_condition.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 2c4cc716b7e3..59b31a3a8652 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -584,7 +584,8 @@ def add_noise_to_image_conditioning_latents( Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially when conditioned on a single frame. 
""" - generator = torch.Generator(device="cpu").manual_seed(0) + # YiYi TODO: testing only, remove this change before merging + # generator = torch.Generator(device="cpu").manual_seed(0) noise = randn_tensor( latents.shape, generator=generator, @@ -618,7 +619,8 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) + # YiYi TODO: testing only, remove this change before merging + # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) @@ -628,8 +630,9 @@ def prepare_latents( extra_conditioning_num_latents = 0 for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) - condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std) - condition_latents = torch.load("/raid/yiyi/LTX-Video/conditioning_latents.pt").to(device, dtype=dtype) + condition_latents = self._normalize_latents( + condition_latents, self.vae.latents_mean, self.vae.latents_std + ).to(device, dtype=dtype) num_data_frames = data.size(2) num_cond_frames = condition_latents.size(2) @@ -659,7 +662,8 @@ def prepare_latents( condition_latents = condition_latents[:, :, :num_prefix_latent_frames] noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) + # YiYi TODO: testing only, remove this change before merging + # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, strength) condition_video_ids = self._prepare_video_ids( From 00e9670d736fc99294e9008166e42ffdd69056da Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 13 Mar 2025 09:22:08 -1000 Subject: [PATCH 15/17] Apply suggestions from code review Co-authored-by: Aryan --- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 59b31a3a8652..3436beb83f27 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -584,8 +584,6 @@ def add_noise_to_image_conditioning_latents( Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially when conditioned on a single frame. 
""" - # YiYi TODO: testing only, remove this change before merging - # generator = torch.Generator(device="cpu").manual_seed(0) noise = randn_tensor( latents.shape, generator=generator, @@ -619,8 +617,6 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # YiYi TODO: testing only, remove this change before merging - # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype) condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) @@ -662,8 +658,6 @@ def prepare_latents( condition_latents = condition_latents[:, :, :num_prefix_latent_frames] noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - # YiYi TODO: testing only, remove this change before merging - # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype) condition_latents = torch.lerp(noise, condition_latents, strength) condition_video_ids = self._prepare_video_ids( From ed2f7e3cd021d6a61e37a145e1109b658761216d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Mar 2025 10:03:14 +0100 Subject: [PATCH 16/17] add docs tests + more refactor --- docs/source/en/api/pipelines/ltx_video.md | 6 + scripts/convert_ltx_to_diffusers.py | 2 + src/diffusers/__init__.py | 2 + .../models/autoencoders/autoencoder_kl_ltx.py | 8 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/ltx/__init__.py | 2 + .../pipelines/ltx/pipeline_ltx_condition.py | 122 ++++++-- .../scheduling_flow_match_euler_discrete.py | 18 +- tests/pipelines/ltx/test_ltx_condition.py | 284 ++++++++++++++++++ 9 files changed, 409 insertions(+), 39 deletions(-) create mode 100644 tests/pipelines/ltx/test_ltx_condition.py diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index f31c621293fc..4bc22c0f9f6c 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXConditionPipeline + +[[autodoc]] LTXConditionPipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 2e966d5d110b..52a7791bb2fc 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, "model.diffusion_model": remove_keys_, + "decoder.timestep_scale_multiplier": remove_keys_, } @@ -270,6 +271,7 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, + "timestep_scale_multiplier": 1000.0, } VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) return config diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cfb0bd08f818..704dee331f7e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -347,6 +347,7 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXPipeline", "Lumina2Text2ImgPipeline", @@ -857,6 +858,7 @@ LDMTextToImagePipeline, 
LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 2b2f77a5509d..9384f8863b16 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -921,12 +921,14 @@ def __init__( timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1), + timestep_scale_multiplier: float = 1.0, ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.timestep_scale_multiplier = timestep_scale_multiplier block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -981,9 +983,7 @@ def __init__( # timestep embedding self.time_embedder = None self.scale_shift_table = None - self.timestep_scale_multiplier = None if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -992,7 +992,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) - if self.timestep_scale_multiplier is not None: + if temb is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -1107,6 +1107,7 @@ def __init__( decoder_causal: bool = False, spatial_compression_ratio: int = None, temporal_compression_ratio: int = None, + timestep_scale_multiplier: float = 1.0, ) -> None: super().__init__() @@ -1137,6 +1138,7 @@ def __init__( inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + timestep_scale_multiplier=timestep_scale_multiplier, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e99162e7a7fe..af5ffdca2152 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -260,7 +260,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -610,7 +610,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXImageToVideoPipeline, LTXPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaText2ImgPipeline from .lumina2 import Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 20cc1c216522..199e730d9b4d 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,6 +23,7 @@ 
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -34,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 3436beb83f27..ff504dc7582b 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -21,6 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTXVideo from ...models.transformers import LTXVideoTransformer3DModel @@ -45,12 +46,11 @@ Examples: ```py >>> import torch - >>> from diffusers import LTXImageToVideoPipeline + >>> from diffusers import LTXConditionPipeline >>> from diffusers.utils import export_to_video, load_image - >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> image = load_image( ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" ... ) @@ -405,6 +405,11 @@ def encode_prompt( def check_inputs( self, prompt, + conditions, + image, + video, + frame_index, + strength, height, width, callback_on_step_end_tensor_inputs=None, @@ -455,6 +460,26 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None and (image is None and video is None): + raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `image` and `frame_index` must be of the same length." + ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `video` and `frame_index` must be of the same length." 
+ ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + @staticmethod def _prepare_video_ids( batch_size: int, @@ -699,7 +724,8 @@ def prepare_latents( patch_size=self.transformer_spatial_patch_size, device=device, ) - video_ids_scaled = self._scale_video_ids( + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) + video_ids = self._scale_video_ids( video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, @@ -709,11 +735,10 @@ def prepare_latents( latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) if len(extra_conditioning_latents) > 0: latents = torch.cat([*extra_conditioning_latents, latents], dim=1) - video_ids = torch.cat([*extra_conditioning_video_ids, video_ids_scaled], dim=2) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) return latents, conditioning_mask, video_ids, extra_conditioning_num_latents @@ -742,7 +767,11 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], + conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None, + image: Union[PipelineImageInput, List[PipelineImageInput]] = None, + video: List[PipelineImageInput] = None, + frame_index: Union[int, List[int]] = 0, + strength: Union[float, List[float]] = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -773,8 +802,19 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - conditions (`List[LTXVideoCondition]`): - The list of frame-conditioning items for the video generation. + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`List[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + frame_index (`int` or `List[int]`, *optional*): + The frame index or frame indices at which the image or video will conditionally effect the video + generation. If not provided, one has to pass `conditions`. + strength (`float` or `List[float]`, *optional*): + The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -857,6 +897,11 @@ def __call__( # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt=prompt, + conditions=conditions, + image=image, + video=video, + frame_index=frame_index, + strength=strength, height=height, width=width, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -878,6 +923,31 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + + strength = [condition.strength for condition in conditions] + frame_index = [condition.frame_index for condition in conditions] + image = [condition.image for condition in conditions] + video = [condition.video for condition in conditions] + else: + if not isinstance(image, list): + image = [image] + num_conditions = 1 + elif isinstance(image, list): + num_conditions = len(image) + if not isinstance(video, list): + video = [video] + num_conditions = 1 + elif isinstance(video, list): + num_conditions = len(video) + + if not isinstance(frame_index, list): + frame_index = [frame_index] * num_conditions + if not isinstance(strength, list): + strength = [strength] * num_conditions + device = self._execution_device # 3. Prepare text embeddings @@ -905,17 +975,20 @@ def __call__( vae_dtype = self.vae.dtype conditioning_tensors = [] - conditioning_strengths = [] - conditioning_start_frames = [] - - for condition in conditions: - if condition.image is not None: - condition_tensor = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) - elif condition.video is not None: - condition_tensor = self.video_processor.preprocess_video(condition.video, height, width) + for condition_image, condition_video, condition_frame_index, condition_strength in zip( + image, video, frame_index, strength + ): + if condition_image is not None: + condition_tensor = ( + self.video_processor.preprocess(condition_image, height, width) + .unsqueeze(2) + .to(device, dtype=vae_dtype) + ) + elif condition_video is not None: + condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) num_frames_input = condition_tensor.size(2) num_frames_output = self.trim_conditioning_sequence( - condition.frame_index, num_frames_input, num_frames + condition_frame_index, num_frames_input, num_frames ) condition_tensor = condition_tensor[:, :, :num_frames_output] condition_tensor = condition_tensor.to(device, dtype=vae_dtype) @@ -928,15 +1001,13 @@ def __call__( f"but got {condition_tensor.size(2)} frames." ) conditioning_tensors.append(condition_tensor) - conditioning_strengths.append(condition.strength) - conditioning_start_frames.append(condition.frame_index) # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( conditioning_tensors, - conditioning_strengths, - conditioning_start_frames, + strength, + frame_index, batch_size=batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents, height=height, @@ -1015,9 +1086,10 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) - denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0] - t_eps = 1e-6 - tokens_to_denoise_mask = (t / 1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) + denoised_latents = self.scheduler.step( + -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False + )[0] + tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) if callback_on_step_end is not None: diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index b391e487188a..cbb27e5fad63 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -377,6 +377,7 @@ def step( s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ @@ -397,6 +398,8 @@ def step( Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. @@ -421,20 +424,17 @@ def step( ), ) - use_per_token_timesteps = isinstance(timestep, torch.Tensor) and timestep.ndim == 2 if self.step_index is None: - if not use_per_token_timesteps: - self._init_step_index(timestep) + self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - if use_per_token_timesteps: - t_eps = 1e-6 - per_token_sigmas = timestep / self.config.num_train_timesteps + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps sigmas = self.sigmas[:, None, None] - lower_mask = sigmas < per_token_sigmas[None] - t_eps + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 lower_sigmas = lower_mask * sigmas lower_sigmas, _ = lower_sigmas.max(dim=0) dt = (per_token_sigmas - lower_sigmas)[..., None] @@ -446,8 +446,8 @@ def step( prev_sample = sample + dt * model_output # upon completion increase step index by one - if not use_per_token_timesteps: - self._step_index += 1 + self._step_index += 1 + if per_token_timesteps is None: # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py new file mode 100644 index 000000000000..dbb9a740b433 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_condition.py @@ -0,0 +1,284 @@ +# Copyright 2024 The HuggingFace Team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXConditionPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0, use_conditions=False): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + if use_conditions: + conditions = LTXVideoCondition( + image=image, + ) + else: + 
conditions = None + + inputs = { + "conditions": conditions, + "image": None if use_conditions else image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs2 = self.get_dummy_inputs(device, use_conditions=True) + video = pipe(**inputs).frames + generated_video = video[0] + video2 = pipe(**inputs2).frames + generated_video2 = video2[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + max_diff = np.abs(generated_video - generated_video2).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + 
components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) From b98d69cd47edb6283d67d238165ee583bea3b450 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 14 Mar 2025 10:04:01 +0100 Subject: [PATCH 17/17] up --- .../pipelines/ltx/pipeline_ltx_condition.py | 1 - .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index ff504dc7582b..515950de507e 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -401,7 +401,6 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs def check_inputs( self, prompt, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5a2818c2e245..9159c30eefb1 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1187,6 +1187,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + 
requires_backends(cls, ["torch", "transformers"]) + + class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]
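End to end, the series adds `LTXConditionPipeline` together with its docs entry, dummy objects, and fast tests. A usage sketch consistent with the pipeline's example docstring and the new tests — the checkpoint id is the placeholder used in this patch series, and the prompt, resolution, and step count are illustrative:

```py
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the start of the generated clip on a single image.
image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
condition = LTXVideoCondition(image=image, frame_index=0, strength=1.0)

video = pipe(
    conditions=[condition],
    prompt="A young girl stands calmly in the foreground",
    negative_prompt="worst quality, blurry, jittery, distorted",
    width=704,
    height=480,
    num_frames=161,  # 8 * k + 1 frames, matching the VAE's temporal compression
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]

export_to_video(video, "output.mp4", fps=24)
```

Per the new input handling in `__call__`, passing `image=`/`video=` together with `frame_index` and `strength` directly is equivalent to building `LTXVideoCondition` objects as above.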