
Commit f6742ea

fix(training): lr scheduler doesn't work properly in distributed scenario
1 parent: 42cae93 · commit: f6742ea

File tree

47 files changed, +843 / −324 lines (only a subset of the changed files is shown below)
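Every training script touched by this commit receives the same change. The learning-rate scheduler is created before `accelerator.prepare`, while `len(train_dataloader)` still reflects the un-sharded dataloader, so when `--max_train_steps` is not passed the old code sized the scheduler from a length roughly `accelerator.num_processes` times larger than what each process actually iterates over after sharding. The new code estimates the post-sharding length up front (`len_train_dataloader_after_sharding`), derives the scheduler's step budget from it, and, after `accelerator.prepare`, warns if the prepared dataloader turns out to have a different length than was assumed.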


examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 18 additions & 7 deletions
@@ -1524,17 +1524,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         torch.cuda.empty_cache()

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1551,8 +1556,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
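The scripts below apply the same pattern, so a single numeric sketch is enough to show the effect. The following is a minimal, hypothetical example (the batch counts and names such as `num_processes` and `grad_accum_steps` are illustrative, not repository code); it compares the scheduler step budget computed from the un-sharded dataloader length with the budget computed from the estimated post-sharding length, for the case where `--max_train_steps` is not set.

import math

num_processes = 4      # hypothetical: number of training processes / GPUs
grad_accum_steps = 1   # hypothetical: args.gradient_accumulation_steps
num_train_epochs = 1   # hypothetical: args.num_train_epochs
unsharded_len = 1000   # hypothetical: len(train_dataloader) before accelerator.prepare

# Old math: step budget derived from the un-sharded dataloader length.
old_updates_per_epoch = math.ceil(unsharded_len / grad_accum_steps)               # 1000
old_scheduler_budget = num_train_epochs * old_updates_per_epoch * num_processes   # 4000

# New math: estimate the dataloader length after sharding across processes first.
len_after_sharding = math.ceil(unsharded_len / num_processes)                     # 250
new_updates_per_epoch = math.ceil(len_after_sharding / grad_accum_steps)          # 250
new_scheduler_budget = num_train_epochs * new_updates_per_epoch * num_processes   # 1000

print(old_scheduler_budget, new_scheduler_budget)  # 4000 vs. 1000

Assuming, as the `* accelerator.num_processes` factors in these scripts suggest, that the prepared scheduler advances `num_processes` ticks per optimizer update, one epoch in this sketch produces about 1,000 scheduler ticks; with the old budget of 4,000 the warmup and decay phases are stretched out and the schedule never completes, while the new budget matches the ticks that actually occur.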

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 18 additions & 7 deletions
@@ -1820,17 +1820,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         torch.cuda.empty_cache()

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1847,8 +1852,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -1111,11 +1111,16 @@ def compute_time_ids(original_size, crops_coords_top_left):

     # 15. LR Scheduler creation
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     if args.scale_lr:
         args.learning_rate = (
@@ -1130,8 +1135,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )

     # 16. Prepare for training
@@ -1142,8 +1147,14 @@ def compute_time_ids(original_size, crops_coords_top_left):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

examples/controlnet/train_controlnet.py

Lines changed: 18 additions & 7 deletions
@@ -931,17 +931,22 @@ def load_model_hook(models, input_dir):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -966,8 +971,14 @@ def load_model_hook(models, input_dir):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -1088,17 +1088,22 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1110,8 +1115,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 18 additions & 7 deletions
@@ -1040,17 +1040,22 @@ def main(args):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )

     # Prepare everything with our `accelerator`.
@@ -1065,8 +1070,14 @@ def main(args):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

examples/dreambooth/train_dreambooth.py

Lines changed: 18 additions & 7 deletions
@@ -1114,17 +1114,22 @@ def compute_text_embeddings(prompt):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1156,8 +1161,14 @@ def compute_text_embeddings(prompt):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
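Each script also gains the post-`prepare` consistency check shown in the hunks above. The sketch below re-creates that check in isolation with hypothetical numbers (none of these values or helper names come from the repository): if the prepared dataloader ends up shorter than the pre-`prepare` estimate, the scheduler budget chosen earlier no longer matches the recomputed `max_train_steps`, and the scripts can only warn because the scheduler has already been constructed.

import math

num_processes = 4      # hypothetical
grad_accum_steps = 1   # hypothetical
num_train_epochs = 1   # hypothetical

# Estimate made when the scheduler was created (before accelerator.prepare).
len_train_dataloader_after_sharding = 250
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / grad_accum_steps)
    * num_processes
)  # 1000

# Length actually observed after accelerator.prepare (hypothetically smaller).
prepared_len = 240
num_update_steps_per_epoch = math.ceil(prepared_len / grad_accum_steps)  # 240
max_train_steps = num_train_epochs * num_update_steps_per_epoch          # 240

if num_training_steps_for_scheduler != max_train_steps * num_processes:
    print(
        f"scheduler was sized for {num_training_steps_for_scheduler} steps, "
        f"but training will only produce {max_train_steps * num_processes}"
    )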
