18 | 18 | import math
19 | 19 | import os
20 | 20 | import random
   | 21 | +import warnings
21 | 22 | from pathlib import Path
22 | 23 | from typing import Optional
23 | 24 |

54 | 55 | from diffusers.utils.import_utils import is_xformers_available
55 | 56 |
56 | 57 |
   | 58 | +if is_wandb_available():
   | 59 | +    import wandb
   | 60 | +
57 | 61 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
58 | 62 |     PIL_INTERPOLATION = {
59 | 63 |         "linear": PIL.Image.Resampling.BILINEAR,
79 | 83 | logger = get_logger(__name__)
80 | 84 |
81 | 85 |
   | 86 | +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
   | 87 | +    logger.info(
   | 88 | +        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
   | 89 | +        f" {args.validation_prompt}."
   | 90 | +    )
   | 91 | +    # create pipeline (note: unet and vae are loaded again in float32)
   | 92 | +    pipeline = DiffusionPipeline.from_pretrained(
   | 93 | +        args.pretrained_model_name_or_path,
   | 94 | +        text_encoder=accelerator.unwrap_model(text_encoder),
   | 95 | +        tokenizer=tokenizer,
   | 96 | +        unet=unet,
   | 97 | +        vae=vae,
   | 98 | +        revision=args.revision,
   | 99 | +        torch_dtype=weight_dtype,
   | 100 | +    )
   | 101 | +    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
   | 102 | +    pipeline = pipeline.to(accelerator.device)
   | 103 | +    pipeline.set_progress_bar_config(disable=True)
   | 104 | +
   | 105 | +    # run inference
   | 106 | +    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
   | 107 | +    images = []
   | 108 | +    for _ in range(args.num_validation_images):
   | 109 | +        with torch.autocast("cuda"):
   | 110 | +            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
   | 111 | +        images.append(image)
   | 112 | +
   | 113 | +    for tracker in accelerator.trackers:
   | 114 | +        if tracker.name == "tensorboard":
   | 115 | +            np_images = np.stack([np.asarray(img) for img in images])
   | 116 | +            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
   | 117 | +        if tracker.name == "wandb":
   | 118 | +            tracker.log(
   | 119 | +                {
   | 120 | +                    "validation": [
   | 121 | +                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
   | 122 | +                    ]
   | 123 | +                }
   | 124 | +            )
   | 125 | +
   | 126 | +    del pipeline
   | 127 | +    torch.cuda.empty_cache()
   | 128 | +
   | 129 | +
82 | 130 | def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
83 | 131 |     logger.info("Saving embeddings")
84 | 132 |     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
@@ -268,12 +316,22 @@ def parse_args():
268 | 316 |         default=4,
269 | 317 |         help="Number of images that should be generated during validation with `validation_prompt`.",
270 | 318 |     )
    | 319 | +    parser.add_argument(
    | 320 | +        "--validation_steps",
    | 321 | +        type=int,
    | 322 | +        default=100,
    | 323 | +        help=(
    | 324 | +            "Run validation every X steps. Validation consists of running the prompt"
    | 325 | +            " `args.validation_prompt` `args.num_validation_images` times"
    | 326 | +            " and logging the images."
    | 327 | +        ),
    | 328 | +    )
271 | 329 |     parser.add_argument(
272 | 330 |         "--validation_epochs",
273 | 331 |         type=int,
274 |     | -        default=50,
    | 332 | +        default=None,
275 | 333 |         help=(
276 |     | -            "Run validation every X epochs. Validation consists of running the prompt"
    | 334 | +            "Deprecated in favor of `validation_steps`. Run validation every X epochs. Validation consists of running the prompt"
277 | 335 |             " `args.validation_prompt` multiple times: `args.num_validation_images`"
278 | 336 |             " and logging the images."
279 | 337 |         ),
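For context, a minimal sketch of how the two flags interact at the parser level (a standalone toy parser, not the script's full `parse_args`): `validation_steps` now carries the default cadence, while `validation_epochs` defaults to `None` so the code can tell whether the user explicitly passed the deprecated flag.

```python
import argparse

# Toy parser mirroring just these two flags (illustration only).
parser = argparse.ArgumentParser()
parser.add_argument("--validation_steps", type=int, default=100)
parser.add_argument("--validation_epochs", type=int, default=None)

args = parser.parse_args(["--validation_epochs", "5"])
# validation_epochs is non-None only when explicitly passed, which is what
# the deprecation shim in main() keys on before overwriting validation_steps.
print(args.validation_steps, args.validation_epochs)  # -> 100 5
```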
@@ -488,7 +546,6 @@ def main():
488 | 546 |     if args.report_to == "wandb":
489 | 547 |         if not is_wandb_available():
490 | 548 |             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
491 |     | -        import wandb
492 | 549 |
493 | 550 |     # Make one log on every process with the configuration for debugging.
494 | 551 |     logging.basicConfig(
@@ -627,6 +684,15 @@ def main():
627 | 684 |     train_dataloader = torch.utils.data.DataLoader(
628 | 685 |         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
629 | 686 |     )
    | 687 | +    if args.validation_epochs is not None:
    | 688 | +        warnings.warn(
    | 689 | +            f"You are doing logging with validation_epochs={args.validation_epochs}."
    | 690 | +            " `validation_epochs` is deprecated in favor of `validation_steps`."
    | 691 | +            f" Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}.",
    | 692 | +            FutureWarning,
    | 693 | +            stacklevel=2,
    | 694 | +        )
    | 695 | +        args.validation_steps = args.validation_epochs * len(train_dataset)
630 | 696 |
631 | 697 |     # Scheduler and math around the number of training steps.
632 | 698 |     overrode_max_train_steps = False
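Worth noting (our reading, not stated in the diff): the shim converts epochs to steps by multiplying by `len(train_dataset)`, which equals optimizer steps per epoch only when `train_batch_size` is 1 and no gradient accumulation is used; with larger batches, validation fires less often than the old per-epoch cadence. A quick check under assumed values:

```python
import math

# Assumed values for illustration only.
dataset_len, batch_size, grad_accum = 600, 1, 1
validation_epochs = 2

steps_per_epoch = math.ceil(dataset_len / batch_size / grad_accum)  # 600
validation_steps = validation_epochs * dataset_len                  # 1200

# With batch_size 1 and no accumulation, the shim reproduces the old cadence:
assert validation_steps == validation_epochs * steps_per_epoch

# With batch_size 4 an epoch is only 150 steps, so the same shim fires
# validation once every 8 epochs instead of every 2:
steps_per_epoch_b4 = math.ceil(dataset_len / 4 / grad_accum)  # 150
assert validation_steps == 8 * steps_per_epoch_b4
```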
@@ -683,7 +749,6 @@ def main():
683 | 749 |     logger.info(f"  Total optimization steps = {args.max_train_steps}")
684 | 750 |     global_step = 0
685 | 751 |     first_epoch = 0
686 |     | -
687 | 752 |     # Potentially load in the weights and states from a previous save
688 | 753 |     if args.resume_from_checkpoint:
689 | 754 |         if args.resume_from_checkpoint != "latest":
@@ -783,60 +848,15 @@ def main():
783 | 848 |                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
784 | 849 |                         accelerator.save_state(save_path)
785 | 850 |                         logger.info(f"Saved state to {save_path}")
    | 851 | +                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
    | 852 | +                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
786 | 853 |
787 | 854 |             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
788 | 855 |             progress_bar.set_postfix(**logs)
789 | 856 |             accelerator.log(logs, step=global_step)
790 | 857 |
791 | 858 |             if global_step >= args.max_train_steps:
792 | 859 |                 break
793 |     | -
794 |     | -        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
795 |     | -            logger.info(
796 |     | -                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
797 |     | -                f" {args.validation_prompt}."
798 |     | -            )
799 |     | -            # create pipeline (note: unet and vae are loaded again in float32)
800 |     | -            pipeline = DiffusionPipeline.from_pretrained(
801 |     | -                args.pretrained_model_name_or_path,
802 |     | -                text_encoder=accelerator.unwrap_model(text_encoder),
803 |     | -                tokenizer=tokenizer,
804 |     | -                unet=unet,
805 |     | -                vae=vae,
806 |     | -                revision=args.revision,
807 |     | -                torch_dtype=weight_dtype,
808 |     | -            )
809 |     | -            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
810 |     | -            pipeline = pipeline.to(accelerator.device)
811 |     | -            pipeline.set_progress_bar_config(disable=True)
812 |     | -
813 |     | -            # run inference
814 |     | -            generator = (
815 |     | -                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
816 |     | -            )
817 |     | -            images = []
818 |     | -            for _ in range(args.num_validation_images):
819 |     | -                with torch.autocast("cuda"):
820 |     | -                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
821 |     | -                images.append(image)
822 |     | -
823 |     | -            for tracker in accelerator.trackers:
824 |     | -                if tracker.name == "tensorboard":
825 |     | -                    np_images = np.stack([np.asarray(img) for img in images])
826 |     | -                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
827 |     | -                if tracker.name == "wandb":
828 |     | -                    tracker.log(
829 |     | -                        {
830 |     | -                            "validation": [
831 |     | -                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
832 |     | -                                for i, image in enumerate(images)
833 |     | -                            ]
834 |     | -                        }
835 |     | -                    )
836 |     | -
837 |     | -            del pipeline
838 |     | -            torch.cuda.empty_cache()
839 |     | -
840 | 860 |     # Create the pipeline using the trained modules and save it.
841 | 861 |     accelerator.wait_for_everyone()
842 | 862 |     if accelerator.is_main_process:
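Finally, a hedged usage sketch of the new helper outside the training loop (our construction, not from the PR; `args` is a hypothetical stand-in for the `parse_args()` namespace, and the checkpoint name and prompt are assumed values):

```python
from types import SimpleNamespace

import torch
from accelerate import Accelerator

accelerator = Accelerator()
args = SimpleNamespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # assumed checkpoint
    validation_prompt="A <cat-toy> floating in space",  # assumed placeholder prompt
    num_validation_images=4,
    revision=None,
    seed=42,
)

# text_encoder, tokenizer, unet and vae come from the training setup, so the
# actual call is shown here rather than executed; argument order follows the
# signature in the diff:
# log_validation(text_encoder, tokenizer, unet, vae, args, accelerator,
#                weight_dtype=torch.float16, epoch=0)
```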