@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save the learned embeddings every X update steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -136,6 +142,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
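
Note: the new --revision argument pins the pretrained weights to a specific branch, tag, or commit on the Hugging Face Hub. A hypothetical invocation, assuming this is the Flax textual-inversion example script (the script name, model id, and token below are illustrative assumptions, not taken from this diff):

    python textual_inversion_flax.py \
        --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
        --revision="bf16" \
        --placeholder_token="<my-token>" \
        --save_steps=500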
@@ -420,9 +433,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
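
Note: threading revision=args.revision through all three loaders keeps the text encoder, VAE, and UNet on the same Hub revision, so a weights branch such as bf16 (where one exists for the checkpoint) is resolved consistently across the submodels.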
@@ -619,6 +632,12 @@ def compute_loss(params):
 
             if global_step >= args.max_train_steps:
                 break
+            if global_step % args.save_steps == 0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
+                    "embedding"
+                ][placeholder_token_id]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
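
Note: jnp.save pickles the embeddings dict into a 0-d NumPy object array on disk. A minimal sketch for loading such a checkpoint back (the file path and placeholder token are illustrative assumptions):

    import numpy as np

    # allow_pickle=True is required because the file holds a pickled dict,
    # and .item() unwraps the dict from the 0-d object array
    learned_embeds_dict = np.load("output/learned_embeds-500.npy", allow_pickle=True).item()
    embedding = learned_embeds_dict["<my-token>"]  # embedding row of shape (hidden_size,)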