Skip to content

Remove conversion to RGB #6479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 12, 2024
24 changes: 17 additions & 7 deletions src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import os
from typing import Union
from typing import Callable, Union

import PIL.Image
import PIL.ImageOps
import requests


def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.

Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
A conversion method to apply to the image after loading it.
When set to `None` the image will be converted "RGB".

Returns:
`PIL.Image.Image`:
A PIL Image.
Expand All @@ -24,14 +30,18 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
)

image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")

if convert_method is not None:
image = convert_method(image)
else:
image = image.convert("RGB")

return image