@@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
2154
2154
2155
2155
# encode batch prompts when custom prompts are provided for each image -
2156
2156
if train_dataset .custom_instance_prompts :
2157
+ elems_to_repeat = 1
2157
2158
if freeze_text_encoder :
2158
2159
prompt_embeds , pooled_prompt_embeds , text_ids = compute_text_embeddings (
2159
2160
prompts , text_encoders , tokenizers
@@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
2168
2169
max_sequence_length = args .max_sequence_length ,
2169
2170
add_special_tokens = add_special_tokens_t5 ,
2170
2171
)
2172
+ else :
2173
+ elems_to_repeat = len (prompts )
2171
2174
2172
2175
if not freeze_text_encoder :
2173
2176
prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
2174
2177
text_encoders = [text_encoder_one , text_encoder_two ],
2175
2178
tokenizers = [None , None ],
2176
- text_input_ids_list = [tokens_one , tokens_two ],
2179
+ text_input_ids_list = [
2180
+ tokens_one .repeat (elems_to_repeat , 1 ),
2181
+ tokens_two .repeat (elems_to_repeat , 1 ),
2182
+ ],
2177
2183
max_sequence_length = args .max_sequence_length ,
2178
2184
device = accelerator .device ,
2179
2185
prompt = prompts ,
2180
2186
)
2181
-
2182
2187
# Convert images to latent space
2183
2188
if args .cache_latents :
2184
2189
model_input = latents_cache [step ].sample ()
@@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
2371
2376
epoch = epoch ,
2372
2377
torch_dtype = weight_dtype ,
2373
2378
)
2379
+ images = None
2380
+ del pipeline
2381
+
2374
2382
if freeze_text_encoder :
2375
2383
del text_encoder_one , text_encoder_two
2376
2384
free_memory ()
@@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
2448
2456
commit_message = "End of training" ,
2449
2457
ignore_patterns = ["step_*" , "epoch_*" ],
2450
2458
)
2459
+ images = None
2460
+ del pipeline
2451
2461
2452
2462
accelerator .end_training ()
2453
2463
0 commit comments