From 25eab22b617c3139d137076d4ad6ce99ae0a0cf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Fri, 4 Apr 2025 06:53:54 -0300 Subject: [PATCH 1/3] initial --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index f0d993ad9bbc..aa63c96f1764 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -669,6 +669,13 @@ def parse_args(input_args=None): ), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + help="The image interpolation method to use for resizing images.", + ) + if input_args is not None: args = parser.parse_args(input_args) else: @@ -790,7 +797,12 @@ def __init__( self.original_sizes = [] self.crop_top_lefts = [] self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + + if args.image_interpolation_mode == "bilinear": + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + else: + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( From 1b9f29e41b8a1012fa739475ec513f17f2f14a7e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 8 Apr 2025 11:46:57 +0530 Subject: [PATCH 2/3] Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: hlky --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index aa63c96f1764..f273b3c3c82e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -798,10 +798,10 @@ def __init__( self.crop_top_lefts = [] self.pixel_values = [] - if args.image_interpolation_mode == "bilinear": - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - else: - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS) + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode.") + train_resize = transforms.Resize(size, interpolation=interpolation) train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) From 0a3226989c78d2e1274895278c96ba75ae7eaf3f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 11:54:22 +0530 Subject: [PATCH 3/3] update --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index f273b3c3c82e..735d48b83400 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -673,6 +673,9 @@ def parse_args(input_args=None): "--image_interpolation_mode", type=str, default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], help="The image interpolation method to use for resizing images.", ) @@ -800,7 +803,7 @@ def __init__( interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) if interpolation is None: - raise ValueError(f"Unsupported interpolation mode.") + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") train_resize = transforms.Resize(size, interpolation=interpolation) train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)