Commit d9b9533

Textual inv make save log both steps (#2178)
* Initial commit
* removed images
* Made logging the same as save
* Removed logging function
* Quality fixes
* Quality fixes
* Tested
* Added support back for validation_epochs
* Fixing styles
* Did changes
* Change to log_validation
* Add extra space after wandb import
* Add extra space after wandb

Co-authored-by: Will Berman <wlbberman@gmail.com>

* Fixed spacing

---------

Co-authored-by: Will Berman <wlbberman@gmail.com>
1 parent 8014848 commit d9b9533

File tree

1 file changed: +71 -51 lines

examples/textual_inversion/textual_inversion.py

Lines changed: 71 additions & 51 deletions
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import warnings
 from pathlib import Path
 from typing import Optional

@@ -54,6 +55,9 @@
 from diffusers.utils.import_utils import is_xformers_available


+if is_wandb_available():
+    import wandb
+
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
         "linear": PIL.Image.Resampling.BILINEAR,
@@ -79,6 +83,50 @@
 logger = get_logger(__name__)


+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=unet,
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
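The loop over `accelerator.trackers` in `log_validation` only sees trackers that were registered with the `Accelerator`. A minimal sketch of that registration, with an illustrative project name (the script itself wires this up from `--report_to`):

from accelerate import Accelerator

# Register one or more trackers; log_with also accepts a list such as
# ["tensorboard", "wandb"].
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("textual_inversion")  # project name is illustrative

for tracker in accelerator.trackers:
    print(tracker.name)  # e.g. "wandb"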
@@ -268,12 +316,22 @@ def parse_args():
         default=4,
         help="Number of images that should be generated during validation with `validation_prompt`.",
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
         help=(
-            "Run validation every X epochs. Validation consists of running the prompt"
+            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
             " `args.validation_prompt` multiple times: `args.num_validation_images`"
             " and logging the images."
         ),
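For illustration, a self-contained sketch of the new flag pair in isolation (just these two arguments, outside the real `parse_args`): `--validation_steps` now carries the cadence with a default of 100, while `--validation_epochs` defaults to `None` and only takes effect when passed explicitly.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation_steps", type=int, default=100)
parser.add_argument("--validation_epochs", type=int, default=None)

args = parser.parse_args(["--validation_epochs", "2"])
assert args.validation_steps == 100  # default, untouched
assert args.validation_epochs == 2   # triggers the deprecation path in main()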
@@ -488,7 +546,6 @@ def main():
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -627,6 +684,15 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
+    if args.validation_epochs is not None:
+        warnings.warn(
+            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " Deprecated validation_epochs in favor of `validation_steps`"
+            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+            FutureWarning,
+            stacklevel=2,
+        )
+        args.validation_steps = args.validation_epochs * len(train_dataset)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
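A worked example of the conversion above, with illustrative numbers. Note that `global_step` advances once per optimizer step, so `len(train_dataset)` matches the steps in one epoch only when the effective batch size is 1:

validation_epochs = 2
len_train_dataset = 500  # stand-in for len(train_dataset)
validation_steps = validation_epochs * len_train_dataset  # -> 1000

# With train_batch_size=4 an epoch is about 125 optimizer steps, so 1000
# steps corresponds to roughly 8 epochs rather than 2; the mapping is
# exact only for an effective batch size of 1.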
@@ -683,7 +749,6 @@ def main():
     logger.info(f" Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
-
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
@@ -783,60 +848,15 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)

             if global_step >= args.max_train_steps:
                 break
-
-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline (note: unet and vae are loaded again in float32)
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                tokenizer=tokenizer,
-                unet=unet,
-                vae=vae,
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = (
-                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            )
-            images = []
-            for _ in range(args.num_validation_images):
-                with torch.autocast("cuda"):
-                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                images.append(image)
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-
-            del pipeline
-            torch.cuda.empty_cache()
-
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
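Net effect of this hunk: validation moves from an epoch-level check at the bottom of the epoch loop into the same step-level block that handles checkpointing, reusing the extracted `log_validation` helper. A minimal runnable sketch of the trigger, with `run_validation` as a hypothetical stand-in:

def run_validation(step):  # hypothetical stand-in for log_validation(...)
    print(f"running validation at step {step}")

validation_steps = 100
global_step = 0
for epoch in range(3):          # dummy epoch loop
    for _ in range(150):        # dummy optimizer steps per epoch
        global_step += 1
        if global_step % validation_steps == 0:
            run_validation(global_step)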
