
Commit 8a366b8

2510 and sayakpaul authored
Fix gradient-checkpointing option is ignored in SDXL+LoRA training. (#6388) (#6402)
* Fix gradient-checkpointing option is ignored in SDXL+LoRA training. (#6388)
* Fix gradient-checkpointing option is ignored in SD+LoRA training.
* Fix gradient checkpoint is not applied to text encoders. (SDXL+LoRA)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 61d223c · commit 8a366b8

File tree: 2 files changed (+9, −0)


examples/text_to_image/train_text_to_image_lora.py (3 additions, 0 deletions)

@@ -486,6 +486,9 @@ def main():
 
     lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
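
For context, a minimal standalone sketch of what this hunk wires up: the --gradient_checkpointing flag now actually toggles activation checkpointing on the UNet. The argument parsing and model ID below are illustrative, not taken from the script (which loads the UNet from args.pretrained_model_name_or_path).

import argparse

from diffusers import UNet2DConditionModel

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", action="store_true")
args = parser.parse_args()

# Illustrative model ID, standing in for args.pretrained_model_name_or_path.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

if args.gradient_checkpointing:
    # Recompute intermediate activations during the backward pass instead of
    # caching them, trading extra compute for lower peak memory.
    unet.enable_gradient_checkpointing()

Before this commit, passing --gradient_checkpointing to the LoRA script had no effect, so peak memory stayed at the non-checkpointed level.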

examples/text_to_image/train_text_to_image_lora_sdxl.py (6 additions, 0 deletions)

@@ -706,6 +706,12 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder_one.gradient_checkpointing_enable()
+            text_encoder_two.gradient_checkpointing_enable()
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
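
And a hedged sketch of the SDXL-side addition: when training the text encoders, checkpointing is enabled on both CLIP encoders via transformers' gradient_checkpointing_enable(). The model ID is illustrative; the script resolves the checkpoint from its CLI arguments.

from transformers import CLIPTextModel, CLIPTextModelWithProjection

# Illustrative model ID, standing in for the script's CLI-provided checkpoint.
base = "stabilityai/stable-diffusion-xl-base-1.0"
text_encoder_one = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base, subfolder="text_encoder_2"
)

# transformers models expose gradient_checkpointing_enable(); note the
# different name from diffusers' enable_gradient_checkpointing().
for encoder in (text_encoder_one, text_encoder_two):
    encoder.gradient_checkpointing_enable()

Note the guard order in the diff: the text encoders are only checkpointed when --gradient_checkpointing is set and --train_text_encoder is enabled, since frozen encoders need no activation recomputation.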
