@@ -433,9 +433,15 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",revision=args.revision)
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",revision=args.revision)
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",revision=args.revision)
+    text_encoder = FlaxCLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+    )
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+    )

     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
@@ -633,11 +639,13 @@ def compute_loss(params):
             if global_step >= args.max_train_steps:
                 break
             if global_step % args.save_steps == 0:
-                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
-                    placeholder_token_id
-                ]
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
+                    "embedding"
+                ][placeholder_token_id]
                 learned_embeds_dict = {args.placeholder_token: learned_embeds}
-                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step)+ ".npy"), learned_embeds_dict)
+                jnp.save(
+                    os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
+                )

         train_metric = jax_utils.unreplicate(train_metric)

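Note on the save call above: jnp.save writes the Python dict as a pickled 0-d NumPy object array, so reading the checkpoint back needs allow_pickle=True followed by .item(). A minimal sketch, assuming an illustrative output path and placeholder token string (neither comes from this diff):

import numpy as np

# Hypothetical path; matches the "learned_embeds-<step>.npy" naming used in the training loop above.
learned = np.load("output/learned_embeds-500.npy", allow_pickle=True).item()

# The dict maps the placeholder token to its learned embedding vector.
embedding = learned["<my-token>"]  # "<my-token>" stands in for args.placeholder_token
print(embedding.shape)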