From 7d283727865e4cf842b301643032a70d4aefbb37 Mon Sep 17 00:00:00 2001 From: Benedikt Lorch Date: Sun, 18 Feb 2024 19:06:09 +0100 Subject: [PATCH 1/2] Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes #6292 --- src/diffusers/pipelines/stable_diffusion_xl/watermark.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index 5b6e36d9f447..a3111defe445 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -28,9 +28,15 @@ def apply_watermark(self, images: torch.FloatTensor): images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() + # Convert RGB to BGR, which is the channel order expected by the watermark encoder. + images = images[:, :, :, ::-1] + images = [self.encoder.encode(image, "dwtDct") for image in images] - images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) + # Convert BGR back to RGB + images = np.array(images)[:, :, :, ::-1] + + images = torch.from_numpy(images).permute(0, 3, 1, 2) images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) return images From f8117865df821b09d8e585c271db05e5de47aa77 Mon Sep 17 00:00:00 2001 From: Benedikt Lorch Date: Tue, 27 Feb 2024 22:21:11 +0100 Subject: [PATCH 2/2] Revert channel order before stacking images to overcome limitations that negative strides are currently not supported --- src/diffusers/pipelines/stable_diffusion_xl/watermark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index a3111defe445..f457cdbdb1eb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -31,10 +31,10 @@ def apply_watermark(self, images: torch.FloatTensor): # Convert RGB to BGR, which is the channel order expected by the watermark encoder. images = images[:, :, :, ::-1] - images = [self.encoder.encode(image, "dwtDct") for image in images] + # Add watermark and convert BGR back to RGB + images = [self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images] - # Convert BGR back to RGB - images = np.array(images)[:, :, :, ::-1] + images = np.array(images) images = torch.from_numpy(images).permute(0, 3, 1, 2)