
Commit 74c009f

apply changes from huggingface#8312
1 parent: 256d77e

File tree

1 file changed: 18 additions, 7 deletions

examples/research_projects/resadapter/train_sd_resadapter.py

Lines changed: 18 additions & 7 deletions

@@ -760,17 +760,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
@@ -780,8 +785,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
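Below is a minimal, standalone sketch (not part of the commit) of the step arithmetic this change introduces, assuming accelerate's default behavior of stepping the prepared scheduler once per process for every optimizer step. The helper name scheduler_steps and its parameters are illustrative stand-ins for the script's arguments, not code from train_sd_resadapter.py.

# Illustrative sketch of the scheduler step math; not part of the commit.
import math


def scheduler_steps(
    dataset_size,
    batch_size,
    num_processes,
    gradient_accumulation_steps,
    num_train_epochs,
    max_train_steps=None,
    lr_warmup_steps=500,
):
    # Dataloader length before accelerate shards the batches across processes.
    len_train_dataloader = math.ceil(dataset_size / batch_size)
    # After `accelerator.prepare`, each process only sees its own shard.
    len_after_sharding = math.ceil(len_train_dataloader / num_processes)
    # Both counts are scaled by num_processes because, by default, the prepared
    # scheduler is stepped once per process for each optimizer step.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
    if max_train_steps is None:
        # Optimizer updates per epoch as seen by a single process.
        num_update_steps_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)
        num_training_steps_for_scheduler = (
            num_train_epochs * num_update_steps_per_epoch * num_processes
        )
    else:
        num_training_steps_for_scheduler = max_train_steps * num_processes
    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler


# Example: 10_000 samples, batch size 16, 4 processes, grad accumulation 2, 3 epochs.
print(scheduler_steps(10_000, 16, 4, 2, 3))  # -> (2000, 948)

Computing the counts from the sharded dataloader length up front means the scheduler is created with the same totals that `accelerator.prepare` will later produce, and the post-prepare warning in the diff only fires when the two disagree.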
