@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )
 
+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
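
The two calls added above are the standard toggles for activation checkpointing: gradient_checkpointing_enable() on the transformers text encoder and enable_gradient_checkpointing() on the diffusers UNet. A minimal standalone sketch of the same calls outside this script, assuming the usual model classes and an illustrative Stable Diffusion checkpoint id:

    from transformers import CLIPTextModel
    from diffusers import UNet2DConditionModel

    # Illustrative checkpoint id; any repo with the same text_encoder/unet layout works.
    model_id = "CompVis/stable-diffusion-v1-4"
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

    # Recompute intermediate activations during backward instead of caching them,
    # trading a slower backward pass for lower peak memory.
    text_encoder.gradient_checkpointing_enable()
    unet.enable_gradient_checkpointing()
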
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
 
     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.scale_lr:
         args.learning_rate = (
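
The freezing hunk above swaps the removed freeze_params helper (and the itertools.chain over text-encoder submodules, which is why the import itertools line was dropped) for PyTorch's built-in Module.requires_grad_(False), which sets the flag recursively on every parameter of the module. A small self-contained sketch of the equivalence, with a toy module standing in for the real models:

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))  # stand-in for vae/unet

    # Removed pattern: walk the parameters and flip each flag by hand.
    for param in toy.parameters():
        param.requires_grad = False

    # New pattern: one in-place call that recurses over all submodules.
    toy.requires_grad_(False)

    assert all(not p.requires_grad for p in toy.parameters())
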
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
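
torch.randn_like(latents) samples the noise directly with the shape, dtype, and device of latents, replacing the removed chain of .to() calls; since the latents come out of the vae already in weight_dtype, the result should be equivalent, and the timestep sampling is only reflowed onto two lines. A tiny sketch of the inherited dtype/device behaviour, with illustrative shapes:

    import torch

    latents = torch.zeros(2, 4, 64, 64, dtype=torch.float16)  # illustrative latent batch

    # New pattern: shape, dtype and device are all inherited from the input tensor.
    noise = torch.randn_like(latents)

    # Removed pattern: sample in the default dtype, then convert.
    noise_old = torch.randn(latents.shape).to(latents.device).to(dtype=latents.dtype)

    assert noise.shape == noise_old.shape and noise.dtype == noise_old.dtype == torch.float16
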
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds: