diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 513afa3dfaee..b8d6ed6bce05 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -715,11 +715,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        # Store normalization parameters as tensors
-        self.mean = torch.tensor(latents_mean)
-        self.std = torch.tensor(latents_std)
-        self.scale = torch.stack([self.mean, 1.0 / self.std])  # Shape: [2, C]
-
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
@@ -751,7 +746,6 @@ def _count_conv3d(model):
         self._enc_feat_map = [None] * self._enc_conv_num
 
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
-        scale = self.scale.type_as(x)
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -770,8 +764,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
 
         enc = self.quant_conv(out)
         mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
-        logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
         enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
@@ -798,10 +790,8 @@ def encode(
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()
-        # z: [b,c,t,h,w]
-        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
 
         iter_ = z.shape[2]
         x = self.post_quant_conv(z)
@@ -835,8 +825,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        scale = self.scale.type_as(z)
-        decoded = self._decode(z, scale).sample
+        decoded = self._decode(z).sample
 
         if not return_dict:
             return (decoded,)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index b1ac912969aa..6fab997e6660 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -563,6 +563,15 @@ def __call__(
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index 24eb5586c34b..863178e7c434 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -392,6 +392,17 @@ def prepare_latents(
 
         latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
@@ -654,6 +665,15 @@ def __call__(
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else: