@@ -24,6 +24,7 @@
 from typing import Optional
 
 import accelerate
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -40,18 +41,71 @@
 from transformers import AutoTokenizer, PretrainedConfig
 
 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.15.0.dev0")
 
 logger = get_logger(__name__)
 
 
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
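A note on the tracker logging above: `log_validation` stacks the PIL outputs into a single NHWC uint8 array for TensorBoard and wraps each image in `wandb.Image` for Weights & Biases. As a self-contained sketch of the TensorBoard half (the solid-color placeholder images and `./logs` directory are illustrative, not part of this PR):

import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

# Placeholder images standing in for pipeline outputs.
images = [Image.new("RGB", (64, 64), color=(60 * i, 0, 0)) for i in range(4)]

# Same layout as log_validation: stack PIL images into an (N, H, W, C) uint8 array.
np_images = np.stack([np.asarray(img) for img in images])

writer = SummaryWriter(log_dir="./logs")  # illustrative log directory
writer.add_images("validation", np_images, global_step=0, dataformats="NHWC")
writer.close()

Also note that the generator is seeded once per call from `args.seed`, so every validation round starts from the same noise and the logged image grids stay comparable across training.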
@@ -306,6 +360,28 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with `validation_prompt`.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of generating"
+            " `args.num_validation_images` images from the prompt `args.validation_prompt`"
+            " and logging them."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
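Since `parse_args` accepts an `input_args` list (visible in the hunk header above), the new flags can be exercised directly. A hypothetical invocation follows; the checkpoint, paths, and prompts are examples, and any further flags the script requires are omitted here:

# Hypothetical example; all values are illustrative, not from this PR.
args = parse_args(
    [
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./instance-images",
        "--output_dir", "./dreambooth-out",
        "--instance_prompt", "a photo of sks dog",
        "--validation_prompt", "a photo of sks dog in a bucket",
        "--num_validation_images", "2",
        "--validation_steps", "50",
        "--report_to", "wandb",
    ]
)
assert args.validation_steps == 50 and args.num_validation_images == 2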
@@ -508,6 +584,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
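The guard above fails fast at startup rather than crashing at the first validation step. `is_wandb_available` is provided by `diffusers.utils`; a minimal stand-in for this kind of check, shown only as a sketch and not the diffusers implementation, could look like:

import importlib.util

def wandb_is_installed() -> bool:
    # True if the wandb package can be imported in the current environment.
    return importlib.util.find_spec("wandb") is not None

if not wandb_is_installed():
    raise ImportError("Make sure to install wandb if you want to use it for logging during training.")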
@@ -920,6 +1000,8 @@ def load_model_hook(models, input_dir):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
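Finally, the hook added to the training loop gates validation on two conditions: a prompt must be configured, and `global_step` must be a multiple of `validation_steps`. A toy loop with hypothetical numbers makes the schedule concrete:

# Toy illustration of the validation cadence; all values are hypothetical.
validation_prompt = "a photo of sks dog"
validation_steps = 100

for global_step in range(1, 301):
    # ... one optimizer step per iteration in the real script ...
    if validation_prompt is not None and global_step % validation_steps == 0:
        print(f"step {global_step}: run log_validation(...)")  # fires at 100, 200, 300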