From daa4fda22de6db63a168a3552a6ca4c3bee13cdd Mon Sep 17 00:00:00 2001 From: Vaibhav Date: Wed, 30 Apr 2025 10:48:45 +0530 Subject: [PATCH 1/4] Add LANCZOS as default interpolation mode. --- examples/controlnet/train_controlnet_sdxl.py | 32 +++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 8b87db958d58..59a6b46ecb71 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -134,8 +134,19 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") - validation_image = validation_image.resize((args.resolution, args.resolution)) - + + # Use the same interpolation mode as in training + if args.interpolation_type.lower() == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + interpolation_mode = transforms.InterpolationMode.BILINEAR + + transform = transforms.Compose([ + transforms.Resize(args.resolution, interpolation=interpolation_mode), + transforms.CenterCrop(args.resolution), + ]) + validation_image = transform(validation_image) + images = [] for _ in range(args.num_validation_images): @@ -587,6 +598,13 @@ def parse_args(input_args=None): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="lanczos", + choices=["lanczos", "bilinear"], + help="The interpolation method to use for resizing images. 
Choose between 'lanczos' (default) and 'bilinear'.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -732,9 +750,15 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom def prepare_train_dataset(dataset, accelerator): + # Set the interpolation mode based on user preference + if args.interpolation_type.lower() == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + interpolation_mode = transforms.InterpolationMode.BILINEAR + image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation_mode), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -743,7 +767,7 @@ def prepare_train_dataset(dataset, accelerator): conditioning_image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation_mode), transforms.CenterCrop(args.resolution), transforms.ToTensor(), ] From a35acc9cd0a8acc254563aeec84cd3dc2cf1f992 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumawat Date: Wed, 30 Apr 2025 18:33:23 +0530 Subject: [PATCH 2/4] LANCZOS as default interpolation --- examples/controlnet/train_controlnet_sdxl.py | 37 +++++++++++--------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 59a6b46ecb71..4a508d3a074a 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -135,18 +135,17 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") - # Use the same 
interpolation mode as in training - if args.interpolation_type.lower() == "lanczos": + # Get the interpolation mode from string + try: + interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper()) + except (AttributeError, KeyError): interpolation_mode = transforms.InterpolationMode.LANCZOS - else: - interpolation_mode = transforms.InterpolationMode.BILINEAR - - transform = transforms.Compose([ - transforms.Resize(args.resolution, interpolation=interpolation_mode), - transforms.CenterCrop(args.resolution), - ]) - validation_image = transform(validation_image) + validation_image = validation_image.resize( + (args.resolution, args.resolution), + resample=Image.Resampling.LANCZOS if interpolation_mode == transforms.InterpolationMode.LANCZOS else Image.Resampling.BILINEAR + ) + images = [] for _ in range(args.num_validation_images): @@ -602,8 +601,11 @@ def parse_args(input_args=None): "--interpolation_type", type=str, default="lanczos", - choices=["lanczos", "bilinear"], - help="The interpolation method to use for resizing images. Choose between 'lanczos' (default) and 'bilinear'.", + help=( + "The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', " + "'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all " + "options. Default is 'lanczos'." + ), ) if input_args is not None: @@ -750,11 +752,14 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom def prepare_train_dataset(dataset, accelerator): - # Set the interpolation mode based on user preference - if args.interpolation_type.lower() == "lanczos": + # Get the interpolation mode from string + try: + interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper()) + except (AttributeError, KeyError): + logger.warning( + f"Interpolation mode {args.interpolation_type} not found. Falling back to LANCZOS." 
+ ) interpolation_mode = transforms.InterpolationMode.LANCZOS - else: - interpolation_mode = transforms.InterpolationMode.BILINEAR image_transforms = transforms.Compose( [ From e94b91d0331c3007d2dfd8a81df870d98151a31f Mon Sep 17 00:00:00 2001 From: Vaibhav Kumawat Date: Wed, 30 Apr 2025 19:37:33 +0530 Subject: [PATCH 3/4] LANCZOS as default interpolation mode --- examples/controlnet/train_controlnet_sdxl.py | 45 +++++++++++--------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 4a508d3a074a..feeceee2e6e4 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -135,16 +135,22 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") - # Get the interpolation mode from string try: - interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper()) + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper()) except (AttributeError, KeyError): - interpolation_mode = transforms.InterpolationMode.LANCZOS - - validation_image = validation_image.resize( - (args.resolution, args.resolution), - resample=Image.Resampling.LANCZOS if interpolation_mode == transforms.InterpolationMode.LANCZOS else Image.Resampling.BILINEAR - ) + supported_interpolation_modes = [ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ] + raise ValueError( + f"Interpolation mode {args.image_interpolation_mode} is not supported. 
" + f"Please select one of the following: {', '.join(supported_interpolation_modes)}" + ) + + transform = transforms.Compose([ + transforms.Resize(args.resolution, interpolation=interpolation), + transforms.CenterCrop(args.resolution), + ]) + validation_image = transform(validation_image) images = [] @@ -598,14 +604,13 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--interpolation_type", + "--image_interpolation_mode", type=str, default="lanczos", - help=( - "The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', " - "'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all " - "options. Default is 'lanczos'." - ), + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", ) if input_args is not None: @@ -752,14 +757,16 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom def prepare_train_dataset(dataset, accelerator): - # Get the interpolation mode from string try: - interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper()) + interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper()) except (AttributeError, KeyError): - logger.warning( - f"Interpolation mode {args.interpolation_type} not found. Falling back to LANCZOS." + supported_interpolation_modes = [ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ] + raise ValueError( + f"Interpolation mode {args.image_interpolation_mode} is not supported. 
" + f"Please select one of the following: {', '.join(supported_interpolation_modes)}" ) - interpolation_mode = transforms.InterpolationMode.LANCZOS image_transforms = transforms.Compose( [ From 0236a91d3ebe2f081eab9ec9cc51df2ef34c7128 Mon Sep 17 00:00:00 2001 From: Vaibhav Date: Wed, 30 Apr 2025 23:07:47 +0530 Subject: [PATCH 4/4] Added LANCZOS as default interpolation mode --- examples/controlnet/train_controlnet_sdxl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index feeceee2e6e4..3368db1ec096 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -134,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") - + try: interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper()) except (AttributeError, KeyError): @@ -145,11 +145,13 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, f"Interpolation mode {args.image_interpolation_mode} is not supported. " f"Please select one of the following: {', '.join(supported_interpolation_modes)}" ) - - transform = transforms.Compose([ - transforms.Resize(args.resolution, interpolation=interpolation), - transforms.CenterCrop(args.resolution), - ]) + + transform = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=interpolation), + transforms.CenterCrop(args.resolution), + ] + ) validation_image = transform(validation_image) images = [] @@ -767,7 +769,7 @@ def prepare_train_dataset(dataset, accelerator): f"Interpolation mode {args.image_interpolation_mode} is not supported. 
" f"Please select one of the following: {', '.join(supported_interpolation_modes)}" ) - + image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=interpolation_mode),