Commit 7780eea

fix(training): lr scheduler doesn't work properly in distributed scenarios
1 parent: 42cae93

1 file changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 18 additions & 7 deletions
@@ -697,17 +697,22 @@ def collate_fn(examples):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )

     # Prepare everything with our `accelerator`.
@@ -717,8 +722,14 @@ def collate_fn(examples):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
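For context, a minimal sketch of the step arithmetic this change relies on. All concrete numbers below (dataset size, process count, gradient accumulation, epochs, warmup) are made-up illustrative values, not the script's defaults, and the sketch assumes accelerate steps a prepared scheduler once per process for every optimizer update, which is why the script multiplies every scheduler quantity by accelerator.num_processes:

import math

# Illustrative values only (hypothetical, not the script's defaults).
num_batches = 1000                 # len(train_dataloader) before accelerator.prepare
num_processes = 4                  # accelerator.num_processes
gradient_accumulation_steps = 2    # args.gradient_accumulation_steps
num_train_epochs = 10              # args.num_train_epochs
lr_warmup_steps = 500              # args.lr_warmup_steps

# After accelerator.prepare the dataloader is sharded, so each process
# only iterates over a fraction of the batches.
len_after_sharding = math.ceil(num_batches / num_processes)                       # 250

# Optimizer updates performed by each process in one epoch.
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)   # 125

# The scheduler is assumed to be stepped num_processes times per optimizer
# update, so it must be sized in scheduler steps, not optimizer updates.
num_training_steps_for_scheduler = num_train_epochs * updates_per_epoch * num_processes  # 5000
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                         # 2000

# The pre-fix code used the unsharded length (1000) here when
# args.max_train_steps was left unset, yielding 10 * ceil(1000 / 2) * 4 = 20000
# scheduler steps for only 5000 actual scheduler.step() calls, so in a
# 4-process run the LR schedule never advanced past a quarter of its length.

In short, before this change the per-epoch update count was computed from the dataloader length before sharding, so in multi-GPU runs the scheduler was sized roughly num_processes times too large and the learning rate warmed up and decayed far too slowly; the fix sizes the scheduler from the sharded length and warns if the dataloader length changes after accelerator.prepare.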
