Commit f3f626d

Author: haixinxu

Allow textual_inversion_flax script to use save_steps and revision flags (#2075)

* Update textual_inversion_flax.py
* Update textual_inversion_flax.py
* Typo sorry.
* Format source
1 parent: b7b4683

File tree

1 file changed (+22, -3 lines)

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 22 additions & 3 deletions
```diff
@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save the learned embeddings every X update steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
```
```diff
@@ -136,6 +142,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
```
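The two new arguments are plain argparse flags. As a minimal sketch (not part of the commit), here is how they parse in isolation; the CLI values passed below are hypothetical:

```python
# Minimal sketch, not from the commit: the two new flags in isolation.
# Defaults mirror the diff above; the values below are hypothetical.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_steps", type=int, default=500)
parser.add_argument("--revision", type=str, default=None, required=False)

args = parser.parse_args(["--save_steps", "1000", "--revision", "flax"])
assert args.save_steps == 1000
assert args.revision == "flax"
```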
```diff
@@ -420,9 +433,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
```
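The parsed `--revision` string is forwarded to each `from_pretrained` call, where it selects a git branch, tag, or commit SHA of the model repository on the Hub. A hedged standalone sketch, assuming the `CompVis/stable-diffusion-v1-4` checkpoint and a `bf16` branch (both are examples, not something this commit pins):

```python
# Sketch under assumptions: the repo id and the "bf16" branch are examples
# only; any branch, tag, or commit SHA that exists on the Hub works the same.
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # example Hub checkpoint
    subfolder="unet",
    revision="bf16",  # e.g. a reduced-precision branch of the repo
)
```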
```diff
@@ -619,6 +632,12 @@ def compute_loss(params):
 
             if global_step >= args.max_train_steps:
                 break
+            if global_step % args.save_steps == 0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+                    placeholder_token_id
+                ]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
 
```
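Since `jnp.save` serializes the `{placeholder_token: embedding}` dict into a pickled `.npy` file, an intermediate snapshot can be read back with NumPy. A sketch, with a hypothetical output directory and placeholder token:

```python
# Sketch only: reload one intermediate snapshot written by the loop above.
# "text-inversion-output" and "<my-concept>" stand in for whatever
# --output_dir and --placeholder_token were actually set to.
import numpy as np

snapshot = np.load("text-inversion-output/learned_embeds-500.npy", allow_pickle=True).item()
learned_embed = snapshot["<my-concept>"]  # embedding vector for the placeholder token
print(learned_embed.shape)
```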
