Skip to content

Commit acf479b

Browse files
[advanced flux training] bug fix + reduce memory cost as in #9829 (#9838)
* memory improvement as done here: #9829 * fix bug * fix bug * style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 03bf77c commit acf479b

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21542154

21552155
# encode batch prompts when custom prompts are provided for each image -
21562156
if train_dataset.custom_instance_prompts:
2157+
elems_to_repeat = 1
21572158
if freeze_text_encoder:
21582159
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
21592160
prompts, text_encoders, tokenizers
@@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21682169
max_sequence_length=args.max_sequence_length,
21692170
add_special_tokens=add_special_tokens_t5,
21702171
)
2172+
else:
2173+
elems_to_repeat = len(prompts)
21712174

21722175
if not freeze_text_encoder:
21732176
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
21742177
text_encoders=[text_encoder_one, text_encoder_two],
21752178
tokenizers=[None, None],
2176-
text_input_ids_list=[tokens_one, tokens_two],
2179+
text_input_ids_list=[
2180+
tokens_one.repeat(elems_to_repeat, 1),
2181+
tokens_two.repeat(elems_to_repeat, 1),
2182+
],
21772183
max_sequence_length=args.max_sequence_length,
21782184
device=accelerator.device,
21792185
prompt=prompts,
21802186
)
2181-
21822187
# Convert images to latent space
21832188
if args.cache_latents:
21842189
model_input = latents_cache[step].sample()
@@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
23712376
epoch=epoch,
23722377
torch_dtype=weight_dtype,
23732378
)
2379+
images = None
2380+
del pipeline
2381+
23742382
if freeze_text_encoder:
23752383
del text_encoder_one, text_encoder_two
23762384
free_memory()
@@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
24482456
commit_message="End of training",
24492457
ignore_patterns=["step_*", "epoch_*"],
24502458
)
2459+
images = None
2460+
del pipeline
24512461

24522462
accelerator.end_training()
24532463

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1648,11 +1648,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16481648
prompt=prompts,
16491649
)
16501650
else:
1651+
elems_to_repeat = len(prompts)
16511652
if args.train_text_encoder:
16521653
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
16531654
text_encoders=[text_encoder_one, text_encoder_two],
16541655
tokenizers=[None, None],
1655-
text_input_ids_list=[tokens_one, tokens_two],
1656+
text_input_ids_list=[
1657+
tokens_one.repeat(elems_to_repeat, 1),
1658+
tokens_two.repeat(elems_to_repeat, 1),
1659+
],
16561660
max_sequence_length=args.max_sequence_length,
16571661
device=accelerator.device,
16581662
prompt=args.instance_prompt,

0 commit comments

Comments (0)