@@ -760,17 +760,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
@@ -780,8 +785,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
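For context, a minimal sketch of the step arithmetic this patch introduces, using made-up numbers (dataset_size, per_device_batch_size and the other values below are illustrative placeholders, not taken from the script). It assumes the prepared scheduler is advanced once per process for every optimizer update, which is why both the warmup and total step counts are scaled by the number of processes, as in the diff above.

import math

# Illustrative values only; none of these come from the training script.
dataset_size = 10_000
per_device_batch_size = 16
num_processes = 4                  # e.g. 4 GPUs launched via accelerate
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 500

# Before accelerator.prepare, len(train_dataloader) counts batches over the full dataset.
len_train_dataloader = math.ceil(dataset_size / per_device_batch_size)                 # 625

# After accelerator.prepare, each process only iterates over its shard of those batches.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 157

# Optimizer updates performed by each process per epoch.
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)                                                                                       # 79

# Size the scheduler in "process steps": every process steps it once per update,
# so the totals are multiplied by num_processes, mirroring the patched code.
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

print(num_training_steps_for_scheduler)  # 948
print(num_warmup_steps_for_scheduler)    # 2000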