Commit c7da8fd

yiyixuxu authored
add intermediate logging for dreambooth training script (#2557)
* add intermediate logging

Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent b8bfef2 commit c7da8fd

File tree

1 file changed: +84 -2 lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 84 additions & 2 deletions
@@ -24,6 +24,7 @@
 from typing import Optional
 
 import accelerate
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -40,18 +41,71 @@
 from transformers import AutoTokenizer, PretrainedConfig
 
 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.15.0.dev0")
 
 logger = get_logger(__name__)
 
 
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
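
As an aside on the tracker loop above: for TensorBoard, the PIL images are stacked into a single NHWC uint8 batch before being passed to `add_images`. A minimal standalone sketch of that conversion (the placeholder images stand in for pipeline outputs):

import numpy as np
from PIL import Image

# Placeholder RGB images standing in for the validation outputs.
images = [Image.new("RGB", (64, 64), color=(40 * i, 0, 0)) for i in range(4)]

# np.asarray on a PIL image yields an (H, W, C) uint8 array; stacking adds the batch axis.
np_images = np.stack([np.asarray(img) for img in images])
print(np_images.shape, np_images.dtype)  # (4, 64, 64, 3) uint8

# A tensorboard SummaryWriter would then log the batch as:
# writer.add_images("validation", np_images, epoch, dataformats="NHWC")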
@@ -306,6 +360,28 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with `validation_prompt`.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
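
A minimal standalone sketch of how the three new flags parse, with everything else from the DreamBooth parser omitted (the prompt value is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation_prompt", type=str, default=None)
parser.add_argument("--num_validation_images", type=int, default=4)
parser.add_argument("--validation_steps", type=int, default=100)

# Only the prompt is supplied; the image count and cadence fall back to their defaults.
args = parser.parse_args(["--validation_prompt", "a photo of sks dog in a bucket"])
print(args.validation_prompt, args.num_validation_images, args.validation_steps)
# a photo of sks dog in a bucket 4 100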
@@ -508,6 +584,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
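
The `is_wandb_available()` guard keeps wandb an optional dependency: the import only happens when the package is installed, and a missing install only errors once `--report_to wandb` is actually requested. A standalone sketch of the same pattern, assuming an importlib-based check (diffusers ships its own helper; this version is just for illustration):

import importlib.util

def wandb_is_installed() -> bool:
    # find_spec returns None when the package cannot be imported.
    return importlib.util.find_spec("wandb") is not None

if wandb_is_installed():
    import wandb  # noqa: F401
else:
    # Mirrors the error raised above when wandb logging is requested but unavailable.
    raise ImportError("Make sure to install wandb if you want to use it for logging during training.")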
@@ -920,6 +1000,8 @@ def load_model_hook(models, input_dir):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
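
Putting the pieces together, validation now fires every `validation_steps` optimizer steps whenever `--validation_prompt` is set. A hypothetical launch of the updated script might look like this (model id, data paths, and prompts are placeholders, not values from this commit):

import subprocess

# `accelerate launch` is the usual entry point for the diffusers example scripts.
subprocess.run(
    [
        "accelerate", "launch", "train_dreambooth.py",
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./dog",
        "--instance_prompt", "a photo of sks dog",
        "--output_dir", "./dreambooth-model",
        "--report_to", "wandb",
        "--validation_prompt", "a photo of sks dog in a bucket",
        "--num_validation_images", "4",
        "--validation_steps", "100",
    ],
    check=True,
)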
