Convert RGB to BGR for the SDXL watermark encoder #7013

Merged
merged 3 commits into from
Apr 26, 2024

@btlorch
Contributor

btlorch commented Feb 18, 2024

What does this PR do?

The watermark encoder used by SDXL expects input images in BGR format. Hence, we reorder the channels from RGB to BGR before encoding, and convert the watermarked BGR images back to RGB afterwards.

Fixes #6292.
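
As a minimal sketch of the channel reordering described above (not the PR's actual diff), reversing the last axis of a channels-last array maps RGB to BGR, and applying the same reversal again restores RGB. The helper name `rgb_to_bgr` and the toy array are assumptions made for illustration only.

```python
import numpy as np


def rgb_to_bgr(images: np.ndarray) -> np.ndarray:
    """Reverse the last (channel) axis. The same call maps BGR back to RGB."""
    return images[..., ::-1]


# Toy batch in channels-last layout [num_images, height, width, num_channels]:
# a single 2x2 image whose pixels are pure red in RGB order.
rgb = np.zeros((1, 2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255

bgr = rgb_to_bgr(rgb)        # red now sits in the last channel
restored = rgb_to_bgr(bgr)   # reversing twice gives back the RGB input
```

Because the reversal is its own inverse, the same slicing can be used before and after the watermark encoding.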

Who can review?

@yiyixuxu @sayakpaul @patrickvonplaten

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

Thanks!
Can you share a script and results that show the difference?

@btlorch
Contributor Author

btlorch commented Feb 19, 2024

Here is a working example. As already mentioned in the issue, the modification does not eliminate the watermarking artifacts, but makes them appear less pronounced.

from imwatermark import WatermarkEncoder
import numpy as np
from PIL import Image  # used by tensor_to_pil below
import torch


# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]


class StableDiffusionXLWatermarker:
    def __init__(self):
        self.watermark = WATERMARK_BITS
        self.encoder = WatermarkEncoder()

        self.encoder.set_watermark("bits", self.watermark)

    def apply_watermark(self, images: torch.FloatTensor, convert_channels: bool):
        """
        :param images: tensor of shape [num_images, num_channels, height, width]
        :param convert_channels: if True, convert RGB to BGR before the watermark encoding
        :return: watermarked images as tensor
        """
        # can't encode images that are smaller than 256
        if images.shape[-1] < 256:
            return images

        images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

        # Convert RGB to BGR, which is the channel order expected by the watermark encoder.
        if convert_channels:
            images = images[:, :, :, ::-1]

        images = [self.encoder.encode(image, "dwtDct") for image in images]
        images = np.array(images)

        # Convert BGR back to RGB
        if convert_channels:
            # Copy array because tensors with negative strides are currently not supported
            images = images[:, :, :, ::-1].copy()

        images = torch.from_numpy(images).permute(0, 3, 1, 2)

        images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
        return images


def pil_to_tensor(img):
    img = np.array(img)

    # Convert from [0, 255] to range [-1, 1]
    img = (img / 255.) * 2 - 1.

    # Convert to tensor, move channels to the front, prepend singleton batch dimension
    img_batch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(dim=0)

    return img_batch


def tensor_to_pil(img_batch):
    """
    :param img_batch: tensor of shape [batch_size, num_channels, height, width]
    :return: PIL image
    """
    batch_size = img_batch.shape[0]
    assert batch_size == 1, "Expected only a single image"

    img_batch = (255 * (img_batch / 2 + 0.5)).permute(0, 2, 3, 1).numpy().round().astype(np.uint8)

    return Image.fromarray(img_batch[0])


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from PIL import Image

    watermark_encoder = StableDiffusionXLWatermarker()

    original_img = np.array(Image.open("cow.jpg"))

    original_img_batch = pil_to_tensor(original_img)

    # Apply watermark to image in RGB order
    watermarked_wrong_channel_order_batch = watermark_encoder.apply_watermark(original_img_batch, convert_channels=False)
    watermarked_wrong_channel_order_img = tensor_to_pil(watermarked_wrong_channel_order_batch)

    # Apply watermark to image in BGR order
    watermarked_correct_channel_order_batch = watermark_encoder.apply_watermark(original_img_batch, convert_channels=True)
    watermarked_correct_channel_order_img = tensor_to_pil(watermarked_correct_channel_order_batch)

    fig, axes = plt.subplots(1, 3, figsize=(20, 8))

    axes[0].imshow(original_img)
    axes[0].set_title("Original")

    axes[1].imshow(watermarked_wrong_channel_order_img)
    axes[1].set_title("Wrong channel order")

    axes[2].imshow(watermarked_correct_channel_order_img)
    axes[2].set_title("Correct channel order")

    fig.tight_layout()
    plt.show()

[Image: invisible_watermark]

Zoom in to see the watermarking artifacts.
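
The `.copy()` after the second channel reversal in the script above is there because, as its comment says, tensors with negative strides are not supported. The following NumPy-only sketch illustrates why the copy helps: slicing with `::-1` returns a view with a negative stride on that axis, while `.copy()` materialises a fresh C-contiguous array with positive strides. The array shape is an arbitrary assumption chosen for illustration.

```python
import numpy as np

# Channels-last toy batch: [num_images, height, width, num_channels]
images = np.zeros((1, 4, 4, 3), dtype=np.uint8)

# Reversing the channel axis yields a view, not new data.
flipped = images[:, :, :, ::-1]

# Copying the view produces a contiguous array with ordinary strides.
flipped_copy = flipped.copy()

has_negative_stride = flipped.strides[-1] < 0
copy_has_negative_stride = flipped_copy.strides[-1] < 0
```

Only the copied array is then safe to hand to a consumer that rejects negative strides.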

@yiyixuxu
Collaborator

Can we see an example using the SDXL pipeline with the watermark, showing the output from main and from this branch?

…hat negative strides are currently not supported
@btlorch
Contributor Author

btlorch commented Feb 27, 2024

You can run experiments by adapting the code from the issue and fixing the random seed. For your convenience, I have copied the snippet here with some additional seeding:

from diffusers import AutoPipelineForText2Image
import torch


device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)

generator = torch.Generator(device=device)
generator = generator.manual_seed(6020)

image_latents = torch.randn(
    (1, pipeline_text2image.unet.config.in_channels, 1024 // 8, 1024 // 8),
    generator=generator,
    device=device,
    dtype=torch.float16,
)

prompt = "A highland cow in the Scottish highlands"
image = pipeline_text2image(prompt=prompt, latents=image_latents).images[0]

@yiyixuxu
Collaborator

I'm just hoping to see some results side by side before we can merge this

we don't need to run extensive experiments for this PR; maybe just one or two examples where we can see a slight improvement. :)
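
One possible way to produce the requested side-by-side view, sketched under the assumption that the two outputs have the same height and can be converted with `np.asarray` (PIL images can). The helper `side_by_side` and the placeholder arrays are hypothetical; they stand in for images generated with the same seed on main and on this branch.

```python
import numpy as np


def side_by_side(*images) -> np.ndarray:
    """Concatenate images of equal height along the width axis."""
    arrays = [np.asarray(img) for img in images]
    if len({a.shape[0] for a in arrays}) != 1:
        raise ValueError("all images must have the same height")
    return np.concatenate(arrays, axis=1)


# Placeholder data standing in for the two pipeline outputs (hypothetical).
main_output = np.zeros((8, 8, 3), dtype=np.uint8)
branch_output = np.full((8, 8, 3), 255, dtype=np.uint8)

comparison = side_by_side(main_output, branch_output)
```

The resulting array can be saved in a lossless format so that compression does not mask the watermarking artifacts.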

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 23, 2024
@sayakpaul
Member

@btlorch hi there!

Could you please address #7013 (comment) for us so that we can ship this?

@github-actions github-actions bot removed the stale Issues that haven't received updates label Mar 24, 2024
@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 17, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Apr 19, 2024
@bghira
Contributor

bghira commented Apr 25, 2024

image
this actually makes the artifacts worse. there is noticeable colour aberration on the fixed sample. do you observe this too?

@bghira
Contributor

bghira commented Apr 25, 2024

before, we have red pixels scattered in the samples.

now there are green and blue pixels scattered

image
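
To put a number on which colour channel carries the scattered pixels, one option is to compare an unwatermarked and a watermarked image channel by channel. This is only a sketch: the function name `per_channel_mean_abs_diff` and the synthetic arrays are assumptions, and real inputs would be the RGB outputs with and without the watermark.

```python
import numpy as np


def per_channel_mean_abs_diff(original, watermarked) -> np.ndarray:
    """Mean absolute pixel difference per channel for [height, width, channels] images."""
    a = np.asarray(original, dtype=np.float64)
    b = np.asarray(watermarked, dtype=np.float64)
    return np.abs(a - b).mean(axis=(0, 1))


# Synthetic example: perturb only the first (red) channel by 2 intensity levels.
original = np.zeros((4, 4, 3), dtype=np.uint8)
perturbed = original.copy()
perturbed[..., 0] = 2

diff = per_channel_mean_abs_diff(original, perturbed)
```

A larger value in one entry of `diff` would indicate that the watermark perturbs that channel more strongly than the others.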

@bghira
Contributor

bghira commented Apr 25, 2024

the earlier example actually totally has its hue changed in the output

image

@yiyixuxu
Collaborator

@bghira
I can't really tell 😬 but I trust your expertise on this
should we close this then?

@bghira
Contributor

bghira commented Apr 26, 2024

i'm certain this is more correct but the outcome is still miserable, i feel we should probably merge it to at least match the expected results of the upstream library.

it's, to me, much worse results though. the colours are thrown far off in the green and blue range this way :(

@yiyixuxu yiyixuxu merged commit ebc99a7 into huggingface:main Apr 26, 2024
sayakpaul added a commit that referenced this pull request May 9, 2024
* debugging

* save the resulting image

* check if order reversing works.

* checking values.

* up

* okay

* checking

* fix

* remove print
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes #6292

* Revert channel order before stacking images to overcome limitations that negative strides are currently not supported

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* debugging

* save the resulting image

* check if order reversing works.

* checking values.

* up

* okay

* checking

* fix

* remove print
Development

Successfully merging this pull request may close these issues.

Color channel order for watermark embedding
5 participants