diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 8b87db958d58..3368db1ec096 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") - validation_image = validation_image.resize((args.resolution, args.resolution)) + + try: + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper()) + except (AttributeError, KeyError): + supported_interpolation_modes = [ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ] + raise ValueError( + f"Interpolation mode {args.image_interpolation_mode} is not supported. " + f"Please select one of the following: {', '.join(supported_interpolation_modes)}" + ) + + transform = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=interpolation), + transforms.CenterCrop(args.resolution), + ] + ) + validation_image = transform(validation_image) images = [] @@ -587,6 +605,15 @@ def parse_args(input_args=None): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom def prepare_train_dataset(dataset, accelerator): + try: + interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper()) + except (AttributeError, KeyError): + supported_interpolation_modes = [ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ] + raise ValueError( + f"Interpolation mode {args.image_interpolation_mode} is not supported. " + f"Please select one of the following: {', '.join(supported_interpolation_modes)}" + ) + image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation_mode), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator): conditioning_image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation_mode), transforms.CenterCrop(args.resolution), transforms.ToTensor(), ]