diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 279aa6fe737b..e129d5f3e366 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,18 +1,24 @@ import os -from typing import Union +from typing import Callable, Union import PIL.Image import PIL.ImageOps import requests -def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None +) -> PIL.Image.Image: """ Loads `image` to a PIL Image. Args: image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional): + A conversion method to apply to the image after loading it. + When set to `None` the image will be converted "RGB". + Returns: `PIL.Image.Image`: A PIL Image. @@ -24,14 +30,18 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.Image.open(image) else: raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." ) - elif isinstance(image, PIL.Image.Image): - image = image else: raise ValueError( - "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." ) + image = PIL.ImageOps.exif_transpose(image) - image = image.convert("RGB") + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + return image