Skip to content

Commit 9c29e93

Browse files
authored
Set LANCZOS as the default interpolation method for image resizing. (#11492)
* Set LANCZOS as the default interpolation method for image resizing. * style: run make style and quality checks
1 parent 071807c commit 9c29e93

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,15 @@ def parse_args(input_args=None):
673673
default=False,
674674
help="Cache the VAE latents",
675675
)
676+
parser.add_argument(
677+
"--image_interpolation_mode",
678+
type=str,
679+
default="lanczos",
680+
choices=[
681+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
682+
],
683+
help="The image interpolation method to use for resizing images.",
684+
)
676685

677686
if input_args is not None:
678687
args = parser.parse_args(input_args)
@@ -907,6 +916,10 @@ def __init__(
907916
self.num_instance_images = len(self.instance_images)
908917
self._length = self.num_instance_images
909918

919+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
920+
if interpolation is None:
921+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
922+
910923
if class_data_root is not None:
911924
self.class_data_root = Path(class_data_root)
912925
self.class_data_root.mkdir(parents=True, exist_ok=True)
@@ -921,7 +934,7 @@ def __init__(
921934

922935
self.image_transforms = transforms.Compose(
923936
[
924-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
937+
transforms.Resize(size, interpolation=interpolation),
925938
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
926939
transforms.ToTensor(),
927940
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)