
Commit 9ea7052

[textual inversion] add gradient checkpointing and small fixes. (#1848)
Co-authored-by: Henrik Forstén <henrik.forsten@gmail.com>

* update TI script
* make flake happy
* fix typo
1 parent 03bf877 commit 9ea7052

File tree

1 file changed: +22 −23 lines

examples/textual_inversion/textual_inversion.py

Lines changed: 22 additions & 23 deletions
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"


-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )

+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.scale_lr:
         args.learning_rate = (
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break

-        accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
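
As a reading aid, here is a minimal sketch (not part of the commit) of the pattern this diff introduces: enabling gradient checkpointing on the text encoder and UNet, freezing the untrained modules with requires_grad_ instead of the removed freeze_params helper, and sampling noise with torch.randn_like. The standalone structure, the checkpoint id, and the local gradient_checkpointing variable are illustrative assumptions; in the script these values come from argparse and Accelerate.

# Minimal sketch, assuming a Stable Diffusion v1-5 style checkpoint layout; the actual
# script loads models from args.pretrained_model_name_or_path and trains via Accelerate.
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint, for illustration only
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

gradient_checkpointing = True  # stands in for args.gradient_checkpointing
if gradient_checkpointing:
    # Recompute activations during the backward pass instead of storing them,
    # trading extra compute for lower peak memory.
    text_encoder.gradient_checkpointing_enable()
    unet.enable_gradient_checkpointing()

# Freeze everything except the token embeddings of the text encoder,
# mirroring the requires_grad_(False) calls that replace freeze_params().
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

# randn_like keeps the device and dtype of the latents, which is why the explicit
# .to(...) chain in the old noise-sampling line is no longer needed.
latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(latents)

In the script itself the new behaviour is opt-in from the command line, for example by adding --gradient_checkpointing to an accelerate launch textual_inversion.py invocation (other arguments omitted).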
