diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index 4bcef10f97c2..4d9683b73fc4 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -1773,7 +1773,7 @@ def denoising_value_valid(dnv): f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: @@ -1924,7 +1924,22 @@ def denoising_value_valid(dnv): self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: