diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 606a88f55b32..16820221b0d2 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -834,6 +834,9 @@ def preprocess_train(examples): for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) if args.center_crop: y1 = max(0, int(round((image.height - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0))) @@ -841,10 +844,6 @@ def preprocess_train(examples): else: y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 0bb57b1f3126..58b096beee8b 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -842,6 +842,9 @@ def preprocess_train(examples): for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) if args.center_crop: y1 = max(0, int(round((image.height - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0))) @@ -849,10 +852,6 @@ def preprocess_train(examples): else: y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image)