diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 6c4214fe1b26..0f640dc33546 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -487,19 +487,21 @@ def prepare_latents( ) -> torch.Tensor: height = height // self.vae_spatial_compression_ratio width = width // self.vae_spatial_compression_ratio - num_frames = ( - (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2) - ) + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 shape = (batch_size, num_channels_latents, num_frames, height, width) mask_shape = (batch_size, 1, num_frames, height, width) if latents is not None: - conditioning_mask = latents.new_zeros(shape) + conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) return latents.to(device=device, dtype=dtype), conditioning_mask if isinstance(generator, list):