Convert RGB to BGR for the SDXL watermark encoder #7013
Conversation
Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes huggingface#6292
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
thanks!
Here is a working example. As already mentioned in the issue, the modification does not eliminate the watermarking artifacts, but it makes them appear less pronounced.

```python
from imwatermark import WatermarkEncoder
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]


class StableDiffusionXLWatermarker:
    def __init__(self):
        self.watermark = WATERMARK_BITS
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def apply_watermark(self, images: torch.FloatTensor, convert_channels: bool):
        """
        :param images: tensor of shape [num_images, num_channels, height, width]
        :param convert_channels: if True, convert RGB to BGR before the watermark encoding
        :return: watermarked images as tensor
        """
        # can't encode images that are smaller than 256
        if images.shape[-1] < 256:
            return images

        images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

        # Convert RGB to BGR, which is the channel order expected by the watermark encoder.
        if convert_channels:
            images = images[:, :, :, ::-1]

        images = [self.encoder.encode(image, "dwtDct") for image in images]
        images = np.array(images)

        # Convert BGR back to RGB
        if convert_channels:
            # Copy array because tensors with negative strides are currently not supported
            images = images[:, :, :, ::-1].copy()

        images = torch.from_numpy(images).permute(0, 3, 1, 2)
        images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
        return images


def pil_to_tensor(img):
    img = np.array(img)
    # Convert from [0, 255] to range [-1, 1]
    img = (img / 255.) * 2 - 1.
    # Convert to tensor, move channels to the front, prepend singleton batch dimension
    img_batch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(dim=0)
    return img_batch


def tensor_to_pil(img_batch):
    """
    :param img_batch: tensor of shape [batch_size, num_channels, height, width]
    :return: PIL image
    """
    batch_size = img_batch.shape[0]
    assert batch_size == 1, "Expected only a single image"
    img_batch = (255 * (img_batch / 2 + 0.5)).permute(0, 2, 3, 1).numpy().round().astype(np.uint8)
    return Image.fromarray(img_batch[0])


if __name__ == "__main__":
    watermark_encoder = StableDiffusionXLWatermarker()

    original_img = np.array(Image.open("cow.jpg"))
    original_img_batch = pil_to_tensor(original_img)

    # Apply watermark to image in RGB order
    watermarked_wrong_channel_order_batch = watermark_encoder.apply_watermark(original_img_batch, convert_channels=False)
    watermarked_wrong_channel_order_img = tensor_to_pil(watermarked_wrong_channel_order_batch)

    # Apply watermark to image in BGR order
    watermarked_correct_channel_order_batch = watermark_encoder.apply_watermark(original_img_batch, convert_channels=True)
    watermarked_correct_channel_order_img = tensor_to_pil(watermarked_correct_channel_order_batch)

    fig, axes = plt.subplots(1, 3, figsize=(20, 8))
    axes[0].imshow(original_img)
    axes[0].set_title("Original")
    axes[1].imshow(watermarked_wrong_channel_order_img)
    axes[1].set_title("Wrong channel order")
    axes[2].imshow(watermarked_correct_channel_order_img)
    axes[2].set_title("Correct channel order")
    fig.tight_layout()
    plt.show()
```

Zoom in to see the watermarking artifacts.
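To double-check that the embedded bits actually survive the encode step, one could also try decoding the watermarked image. This is only a sketch under assumptions: it reuses `watermarked_correct_channel_order_img` and `WATERMARK_BITS` from the example above and assumes the decoder, like the encoder, expects BGR input (as in the Stability-AI detection script).

```python
from imwatermark import WatermarkDecoder
import numpy as np

# Hypothetical decode check; reuses WATERMARK_BITS and
# watermarked_correct_channel_order_img from the example above.
decoder = WatermarkDecoder("bits", len(WATERMARK_BITS))

# Assume the decoder expects BGR images as well, so reverse the channel axis
# and make the array contiguous before handing it to OpenCV-based code.
bgr = np.ascontiguousarray(np.array(watermarked_correct_channel_order_img)[:, :, ::-1])
recovered = decoder.decode(bgr, "dwtDct")

print("Recovered bits match:", list(map(int, recovered)) == WATERMARK_BITS)
```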
Can we see an example using the SDXL pipeline with the watermark and show the output?
Revert channel order before stacking images to overcome limitations that negative strides are currently not supported
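For context, the limitation this commit works around is that `torch.from_numpy` rejects NumPy arrays with negative strides, which is exactly what reversing the channel axis produces. A minimal illustration (the array here is just a made-up placeholder):

```python
import numpy as np
import torch

batch = np.zeros((1, 256, 256, 3), dtype=np.float32)

# Reversing the channel axis returns a view with a negative stride ...
flipped = batch[:, :, :, ::-1]

# ... which torch.from_numpy rejects with a ValueError about negative strides.
# Copying the array makes the memory layout contiguous again:
tensor = torch.from_numpy(flipped.copy())
```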
You can run experiments by adapting the code from the issue and fixing the random seed. For convenience, here is the snippet again with some additional seeding:

```python
from diffusers import AutoPipelineForText2Image
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)

generator = torch.Generator(device=device)
generator = generator.manual_seed(6020)

image_latents = torch.randn(
    (1, pipeline_text2image.unet.config.in_channels, 1024 // 8, 1024 // 8),
    generator=generator,
    device=device,
    dtype=torch.float16,
)

prompt = "A highland cow in the Scottish highlands"
image = pipeline_text2image(prompt=prompt, latents=image_latents).images[0]
```
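To get the side-by-side comparison asked for above, one option is to run the custom watermarker from the earlier example on the generated image. This is only a sketch, not part of the PR: it assumes `StableDiffusionXLWatermarker`, `pil_to_tensor`, and `tensor_to_pil` from the previous snippet are in scope, and note that the pipeline may already apply its own watermark if invisible-watermark is installed.

```python
import matplotlib.pyplot as plt

# Assumes the class and helpers from the earlier example are defined in this script.
watermark_encoder = StableDiffusionXLWatermarker()

# Convert the generated PIL image to the [-1, 1] NCHW tensor format the watermarker expects.
image_batch = pil_to_tensor(image)

# Watermark once with the current (RGB) channel order and once with the BGR order from this PR.
rgb_order_img = tensor_to_pil(watermark_encoder.apply_watermark(image_batch, convert_channels=False))
bgr_order_img = tensor_to_pil(watermark_encoder.apply_watermark(image_batch, convert_channels=True))

fig, axes = plt.subplots(1, 3, figsize=(20, 8))
panels = [(image, "Pipeline output"), (rgb_order_img, "RGB order (current)"), (bgr_order_img, "BGR order (this PR)")]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
fig.tight_layout()
plt.show()
```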
I'm just hoping to see some results side by side before we merge this. We don't need to run extensive experiments for this PR; maybe just one or two examples where we can see a slight improvement. :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@btlorch hi there! Could you please address #7013 (comment) for us so that we can ship this?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@bghira |
I'm certain this is more correct, but the outcome is still miserable. I feel we should probably merge it to at least match the expected results of the upstream library. To me, though, the results are much worse: the colours are thrown far off in the green and blue range this way :(
* Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes #6292
* Revert channel order before stacking images to overcome limitations that negative strides are currently not supported

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
What does this PR do?
The watermark encoder used by SDXL expects input images in BGR format. Hence, we reorder the channels from RGB to BGR before encoding. The watermarked BGR images are then converted back to RGB.
Fixes #6292.
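For reviewers skimming the diff, the core of the change boils down to flipping the channel axis around the encode call and flipping back per image before stacking, so the stacked array has no negative strides. A rough sketch of the watermarker's `apply_watermark` after the change; this illustrates the idea and is not the literal merged code:

```python
import numpy as np
import torch

def apply_watermark(self, images: torch.FloatTensor):
    # Images smaller than 256 pixels cannot be watermarked.
    if images.shape[-1] < 256:
        return images

    # [-1, 1] NCHW tensors -> [0, 255] NHWC NumPy arrays
    images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

    # RGB -> BGR, the channel order the watermark encoder expects
    images = images[:, :, :, ::-1]

    # Encode each image, then flip back to RGB per image before stacking,
    # so the stacked array is contiguous (torch.from_numpy rejects negative strides).
    images = np.array([self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images])

    # Back to [-1, 1] NCHW tensors
    images = torch.from_numpy(images).permute(0, 3, 1, 2)
    return torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
```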
Who can review?
@yiyixuxu @sayakpaul @patrickvonplaten