@@ -111,7 +111,9 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
111
111
f .write (yaml + model_card )
112
112
113
113
114
- def log_validation (text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch ):
114
+ def log_validation (
115
+ text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch
116
+ ):
115
117
logger .info (
116
118
f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
117
119
f" { args .validation_prompt } ."
@@ -644,7 +646,6 @@ def main():
644
646
args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args .variant
645
647
)
646
648
647
-
648
649
# Add the placeholder token in tokenizer_1
649
650
placeholder_tokens = [args .placeholder_token ]
650
651
@@ -875,17 +876,27 @@ def main():
875
876
noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
876
877
877
878
# Get the text embedding for conditioning
878
- encoder_hidden_states_1 = text_encoder_1 (batch ["input_ids_1" ], output_hidden_states = True ).hidden_states [- 2 ].to (dtype = weight_dtype )
879
- encoder_output_2 = text_encoder_2 (batch ["input_ids_2" ].reshape (batch ["input_ids_1" ].shape [0 ], - 1 ), output_hidden_states = True )
879
+ encoder_hidden_states_1 = (
880
+ text_encoder_1 (batch ["input_ids_1" ], output_hidden_states = True )
881
+ .hidden_states [- 2 ]
882
+ .to (dtype = weight_dtype )
883
+ )
884
+ encoder_output_2 = text_encoder_2 (
885
+ batch ["input_ids_2" ].reshape (batch ["input_ids_1" ].shape [0 ], - 1 ), output_hidden_states = True
886
+ )
880
887
encoder_hidden_states_2 = encoder_output_2 .hidden_states [- 2 ].to (dtype = weight_dtype )
881
888
sample_size = unet .config .sample_size * (2 ** (len (vae .config .block_out_channels ) - 1 ))
882
889
original_size = (sample_size , sample_size )
883
- add_time_ids = torch .tensor ([list (original_size + (0 , 0 ) + original_size )], dtype = weight_dtype , device = accelerator .device )
890
+ add_time_ids = torch .tensor (
891
+ [list (original_size + (0 , 0 ) + original_size )], dtype = weight_dtype , device = accelerator .device
892
+ )
884
893
added_cond_kwargs = {"text_embeds" : encoder_output_2 [0 ], "time_ids" : add_time_ids }
885
894
encoder_hidden_states = torch .cat ([encoder_hidden_states_1 , encoder_hidden_states_2 ], dim = - 1 )
886
895
887
896
# Predict the noise residual
888
- model_pred = unet (noisy_latents , timesteps , encoder_hidden_states , added_cond_kwargs = added_cond_kwargs ).sample
897
+ model_pred = unet (
898
+ noisy_latents , timesteps , encoder_hidden_states , added_cond_kwargs = added_cond_kwargs
899
+ ).sample
889
900
890
901
# Get the target for loss depending on the prediction type
891
902
if noise_scheduler .config .prediction_type == "epsilon" :
@@ -961,7 +972,16 @@ def main():
961
972
962
973
if args .validation_prompt is not None and global_step % args .validation_steps == 0 :
963
974
images = log_validation (
964
- text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch
975
+ text_encoder_1 ,
976
+ text_encoder_2 ,
977
+ tokenizer_1 ,
978
+ tokenizer_2 ,
979
+ unet ,
980
+ vae ,
981
+ args ,
982
+ accelerator ,
983
+ weight_dtype ,
984
+ epoch ,
965
985
)
966
986
967
987
logs = {"loss" : loss .detach ().item (), "lr" : lr_scheduler_1 .get_last_lr ()[0 ]}
@@ -1020,4 +1040,3 @@ def main():
1020
1040
1021
1041
if __name__ == "__main__" :
1022
1042
main ()
1023
-
0 commit comments