From 534df629f775e13b4343a7723a53c92332d0c285 Mon Sep 17 00:00:00 2001 From: sang-k Date: Tue, 16 Jan 2024 03:45:14 +0900 Subject: [PATCH 1/2] Enable image resizing to adjust its height and width in StableDiffusionXLInstructPix2PixPipeline --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b06363cffd69..b653d8e9f778 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -858,7 +858,7 @@ def __call__( ) # 4. Preprocess image - image = self.image_processor.preprocess(image).to(device) + image = self.image_processor.preprocess(image, height=height, width=width).to(device) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) From 6a83c48e527be14262fc9e3e6c8e096dd0e54fdf Mon Sep 17 00:00:00 2001 From: sang-k Date: Tue, 16 Jan 2024 03:45:29 +0900 Subject: [PATCH 2/2] Ensure that validation is performed at every 'validation_step', not at every step --- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 7a158f5d0d2d..cab16a633369 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -1109,7 +1109,7 @@ def collate_fn(examples): progress_bar.set_postfix(**logs) ### BEGIN: Perform validation every `validation_epochs` steps - if global_step % args.validation_steps == 0 or global_step == 1: + if global_step % args.validation_steps == 0: if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:"