@@ -673,6 +673,15 @@ def parse_args(input_args=None):
673
673
default = False ,
674
674
help = "Cache the VAE latents" ,
675
675
)
676
+ parser .add_argument (
677
+ "--image_interpolation_mode" ,
678
+ type = str ,
679
+ default = "lanczos" ,
680
+ choices = [
681
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
682
+ ],
683
+ help = "The image interpolation method to use for resizing images." ,
684
+ )
676
685
677
686
if input_args is not None :
678
687
args = parser .parse_args (input_args )
@@ -907,6 +916,10 @@ def __init__(
907
916
self .num_instance_images = len (self .instance_images )
908
917
self ._length = self .num_instance_images
909
918
919
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
920
+ if interpolation is None :
921
+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
922
+
910
923
if class_data_root is not None :
911
924
self .class_data_root = Path (class_data_root )
912
925
self .class_data_root .mkdir (parents = True , exist_ok = True )
@@ -921,7 +934,7 @@ def __init__(
921
934
922
935
self .image_transforms = transforms .Compose (
923
936
[
924
- transforms .Resize (size , interpolation = transforms . InterpolationMode . BILINEAR ),
937
+ transforms .Resize (size , interpolation = interpolation ),
925
938
transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
926
939
transforms .ToTensor (),
927
940
transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments