diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index cab16a633369..e20f0404c587 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -55,6 +55,9 @@
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
 
@@ -67,6 +70,57 @@
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+    global_step,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    val_save_dir = os.path.join(args.output_dir, "validation_images")
+    if not os.path.exists(val_save_dir):
+        os.makedirs(val_save_dir)
+
+    original_image = (
+        lambda image_url_or_path: load_image(image_url_or_path)
+        if urlparse(image_url_or_path).scheme
+        else Image.open(image_url_or_path).convert("RGB")
+    )(args.val_image_url_or_path)
+
+    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+        edited_images = []
+        # Run inference
+        for val_img_idx in range(args.num_validation_images):
+            a_val_img = pipeline(
+                args.validation_prompt,
+                image=original_image,
+                num_inference_steps=20,
+                image_guidance_scale=1.5,
+                guidance_scale=7,
+                generator=generator,
+            ).images[0]
+            edited_images.append(a_val_img)
+            # Save validation images
+            a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            logger_name = "test" if is_final_validation else "validation"
+            tracker.log({logger_name: wandb_table})
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
@@ -447,11 +501,6 @@ def main():
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -1111,11 +1160,6 @@ def collate_fn(examples):
                     ### BEGIN: Perform validation every `validation_epochs` steps
                     if global_step % args.validation_steps == 0:
                         if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
-                            logger.info(
-                                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                                f" {args.validation_prompt}."
-                            )
-
                             # create pipeline
                             if args.use_ema:
                                 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1179,16 @@ def collate_fn(examples):
                                 variant=args.variant,
                                 torch_dtype=weight_dtype,
                             )
-                            pipeline = pipeline.to(accelerator.device)
-                            pipeline.set_progress_bar_config(disable=True)
-
-                            # run inference
-                            # Save validation images
-                            val_save_dir = os.path.join(args.output_dir, "validation_images")
-                            if not os.path.exists(val_save_dir):
-                                os.makedirs(val_save_dir)
-
-                            original_image = (
-                                lambda image_url_or_path: load_image(image_url_or_path)
-                                if urlparse(image_url_or_path).scheme
-                                else Image.open(image_url_or_path).convert("RGB")
-                            )(args.val_image_url_or_path)
-                            with torch.autocast(
-                                str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                            ):
-                                edited_images = []
-                                for val_img_idx in range(args.num_validation_images):
-                                    a_val_img = pipeline(
-                                        args.validation_prompt,
-                                        image=original_image,
-                                        num_inference_steps=20,
-                                        image_guidance_scale=1.5,
-                                        guidance_scale=7,
-                                        generator=generator,
-                                    ).images[0]
-                                    edited_images.append(a_val_img)
-                                    a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
-
-                            for tracker in accelerator.trackers:
-                                if tracker.name == "wandb":
-                                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                                    for edited_image in edited_images:
-                                        wandb_table.add_data(
-                                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                                        )
-                                    tracker.log({"validation": wandb_table})
+
+                            log_validation(
+                                pipeline,
+                                args,
+                                accelerator,
+                                generator,
+                                global_step,
+                                is_final_validation=False,
+                            )
+
                             if args.use_ema:
                                 # Switch back to the original UNet parameters.
                                 ema_unet.restore(unet.parameters())
@@ -1187,7 +1203,6 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
 
@@ -1198,10 +1213,11 @@ def collate_fn(examples):
             tokenizer=tokenizer_1,
             tokenizer_2=tokenizer_2,
             vae=vae,
-            unet=unet,
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
+
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
@@ -1212,30 +1228,15 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
+        if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+            log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+                global_step,
+                is_final_validation=True,
+            )
 
     accelerator.end_training()