Skip to content

Commit a2eda7b

Browse files
committed
fix example style
1 parent 21ab27e commit a2eda7b

File tree

1 file changed: +27 additions, −8 deletions

examples/textual_inversion/textual_inversion_sdxl.py

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,9 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
111111
f.write(yaml + model_card)
112112

113113

114-
def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch):
114+
def log_validation(
115+
text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
116+
):
115117
logger.info(
116118
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
117119
f" {args.validation_prompt}."
@@ -644,7 +646,6 @@ def main():
644646
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
645647
)
646648

647-
648649
# Add the placeholder token in tokenizer_1
649650
placeholder_tokens = [args.placeholder_token]
650651

@@ -875,17 +876,27 @@ def main():
875876
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
876877

877878
# Get the text embedding for conditioning
878-
encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype)
879-
encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True)
879+
encoder_hidden_states_1 = (
880+
text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
881+
.hidden_states[-2]
882+
.to(dtype=weight_dtype)
883+
)
884+
encoder_output_2 = text_encoder_2(
885+
batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
886+
)
880887
encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
881888
sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1))
882889
original_size = (sample_size, sample_size)
883-
add_time_ids = torch.tensor([list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device)
890+
add_time_ids = torch.tensor(
891+
[list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device
892+
)
884893
added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
885894
encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)
886895

887896
# Predict the noise residual
888-
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample
897+
model_pred = unet(
898+
noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
899+
).sample
889900

890901
# Get the target for loss depending on the prediction type
891902
if noise_scheduler.config.prediction_type == "epsilon":
@@ -961,7 +972,16 @@ def main():
961972

962973
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
963974
images = log_validation(
964-
text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
975+
text_encoder_1,
976+
text_encoder_2,
977+
tokenizer_1,
978+
tokenizer_2,
979+
unet,
980+
vae,
981+
args,
982+
accelerator,
983+
weight_dtype,
984+
epoch,
965985
)
966986

967987
logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]}
@@ -1020,4 +1040,3 @@ def main():
10201040

10211041
if __name__ == "__main__":
10221042
main()
1023-

0 commit comments

Comments (0)