
Commit b8bfef2

make style
1 parent f3f626d commit b8bfef2

1 file changed: +15 −7 lines


examples/textual_inversion/textual_inversion_flax.py

Lines changed: 15 additions & 7 deletions
@@ -433,9 +433,15 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",revision=args.revision)
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",revision=args.revision)
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",revision=args.revision)
+    text_encoder = FlaxCLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+    )
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+    )
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
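
For reference, a minimal standalone sketch of the loading pattern this hunk reformats, assuming the same transformers/diffusers Flax classes the example imports; the repo id below is only an illustration, not something fixed by this commit. The diffusers Flax models return a (model, params) tuple, while FlaxCLIPTextModel keeps its weights on model.params.

from transformers import FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # example repo id; substitute your own checkpoint

# FlaxCLIPTextModel returns only the module; its weights live on `text_encoder.params`.
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# The diffusers Flax models return a (model, params) tuple instead.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet")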
@@ -633,11 +639,13 @@ def compute_loss(params):
             if global_step >= args.max_train_steps:
                 break
             if global_step % args.save_steps == 0:
-                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
-                    placeholder_token_id
-                ]
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
+                    "embedding"
+                ][placeholder_token_id]
                 learned_embeds_dict = {args.placeholder_token: learned_embeds}
-                jnp.save(os.path.join(args.output_dir, "learned_embeds-"+str(global_step)+".npy"), learned_embeds_dict)
+                jnp.save(
+                    os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
+                )
 
         train_metric = jax_utils.unreplicate(train_metric)
