From 23d490995dd3effa023b5612fc3ea0afd38eca65 Mon Sep 17 00:00:00 2001 From: yijun-lee Date: Mon, 5 May 2025 20:07:02 +0900 Subject: [PATCH 1/2] Set LANCZOS as the default interpolation method for image resizing. --- .../train_dreambooth_lora_sd15_advanced.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 90d5f86522c3..b6016ad78660 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -673,6 +673,15 @@ def parse_args(input_args=None): default=False, help="Cache the VAE latents", ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -906,6 +915,10 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images + + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") if class_data_root is not None: self.class_data_root = Path(class_data_root) @@ -921,7 +934,7 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(size, interpolation=interpolation), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), From c600c526c05329010fa0fd4941a5e4e267810839 Mon Sep 17 00:00:00 2001 From: yijun-lee Date: Tue, 6 May 2025 00:00:30 +0900 Subject: [PATCH 2/2] style: run make style and quality checks --- .../train_dreambooth_lora_sd15_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index b6016ad78660..58b1aa0e5618 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -915,7 +915,7 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images - + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) if interpolation is None: raise ValueError(f"Unsupported interpolation mode {interpolation=}.")