
Commit 322ef19

reset for error changes
1 parent a2eda7b commit 322ef19

1 file changed: +41 −81 lines changed

examples/textual_inversion/textual_inversion_sdxl.py

Lines changed: 41 additions & 81 deletions
@@ -48,6 +48,7 @@
     DDPMScheduler,
     DiffusionPipeline,
     DPMSolverMultistepScheduler,
+    StableDiffusionPipeline,
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
@@ -111,18 +112,15 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
         f.write(yaml + model_card)


-def log_validation(
-    text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
-):
+def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+        text_encoder_2=text_encoder_2,
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
@@ -361,7 +359,7 @@ def parse_args():
     parser.add_argument(
         "--validation_prompt",
         type=str,
-        default=None,
+        default="A <cat-toy> backpack",
         help="A prompt that is used during validation to verify that the model is learning.",
     )
     parser.add_argument(
@@ -380,16 +378,6 @@ def parse_args():
             " and logging the images."
         ),
     )
-    parser.add_argument(
-        "--validation_epochs",
-        type=int,
-        default=None,
-        help=(
-            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
-            " `args.validation_prompt` multiple times: `args.num_validation_images`"
-            " and logging the images."
-        ),
-    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--checkpointing_steps",
@@ -418,11 +406,6 @@
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--no_safe_serialization",
-        action="store_true",
-        help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
-    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -529,6 +512,7 @@ def __init__(

         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+        self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)

     def __len__(self):
         return self._length
@@ -543,6 +527,18 @@ def __getitem__(self, i):
         placeholder_string = self.placeholder_token
         text = random.choice(self.templates).format(placeholder_string)

+        example["original_size"] = (image.height, image.width)
+
+        if self.center_crop:
+            y1 = max(0, int(round((image.height - self.size) / 2.0)))
+            x1 = max(0, int(round((image.width - self.size) / 2.0)))
+            image = self.crop(image)
+        else:
+            y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))
+            image = transforms.functional.crop(image, y1, x1, h, w)
+
+        example["crop_top_left"] = (y1, x1)
+
         example["input_ids_1"] = self.tokenizer_1(
             text,
             padding="max_length",
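For context: SDXL's UNet is micro-conditioned on the original image size and the crop origin, which is why the dataset now records both before cropping. A minimal standalone sketch of the same bookkeeping, assuming only Pillow and torchvision (the 1280×1024 image is a made-up stand-in):

```python
import torchvision.transforms as transforms
from PIL import Image

size = 1024
image = Image.new("RGB", (1280, 1024))      # stand-in for a training image
original_size = (image.height, image.width)  # recorded before any crop

# Random-crop branch: get_params returns (top, left, height, width)
# for a size x size window, and functional.crop applies it.
y1, x1, h, w = transforms.RandomCrop.get_params(image, (size, size))
image = transforms.functional.crop(image, y1, x1, h, w)

crop_top_left = (y1, x1)  # later packed into SDXL's time ids
```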
@@ -564,13 +560,7 @@ def __getitem__(self, i):

         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            (
-                h,
-                w,
-            ) = (
-                img.shape[0],
-                img.shape[1],
-            )
+            (h, w,) = (img.shape[0], img.shape[1],)
             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

         image = Image.fromarray(img)
@@ -646,6 +636,7 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

+
     # Add the placeholder token in tokenizer_1
     placeholder_tokens = [args.placeholder_token]

@@ -686,21 +677,14 @@ def main():
     # Freeze vae and unet
     vae.requires_grad_(False)
     unet.requires_grad_(False)
+    text_encoder_2.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder_1.text_model.encoder.requires_grad_(False)
     text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
     text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
-        # Keep unet in train mode if we are using gradient checkpointing to save memory.
-        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
-        unet.train()
         text_encoder_1.gradient_checkpointing_enable()
-        text_encoder_2.gradient_checkpointing_enable()
-        unet.enable_gradient_checkpointing()

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
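With these freezes, the only parameters that should still require gradients are the first text encoder's token embeddings. A hedged sanity check (parameter names follow transformers' CLIPTextModel; adjust if your version names them differently):

```python
def assert_only_token_embedding_trainable(text_encoder_1, text_encoder_2, unet, vae):
    # Expected: everything frozen except text_encoder_1's token embedding matrix.
    trainable = [n for n, p in text_encoder_1.named_parameters() if p.requires_grad]
    assert trainable == ["text_model.embeddings.token_embedding.weight"], trainable
    for module in (text_encoder_2, unet, vae):
        assert not any(p.requires_grad for p in module.parameters())
```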
@@ -749,15 +733,6 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
-    if args.validation_epochs is not None:
-        warnings.warn(
-            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
-            " Deprecated validation_epochs in favor of `validation_steps`"
-            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
-            FutureWarning,
-            stacklevel=2,
-        )
-        args.validation_steps = args.validation_epochs * len(train_dataset)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -791,7 +766,7 @@ def main():
     # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_2.to(accelerator.device, dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
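The two forms here behave identically: `nn.Module.to` casts and moves a module's parameters in place and returns `self` (unlike `Tensor.to`, which returns a new tensor), so the reassignment was redundant. A quick illustration:

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
returned = module.to(dtype=torch.float16)  # in-place for modules, unlike tensor.to
assert returned is module
assert module.weight.dtype == torch.float16
```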
@@ -876,27 +851,18 @@ def main():
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
-                encoder_hidden_states_1 = (
-                    text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
-                    .hidden_states[-2]
-                    .to(dtype=weight_dtype)
-                )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype)
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True)
                encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
-                sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1))
-                original_size = (sample_size, sample_size)
-                add_time_ids = torch.tensor(
-                    [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device
-                )
+                original_size = [(batch["original_size"][0][i].item(), batch["original_size"][1][i].item()) for i in range(args.train_batch_size)]
+                crop_top_left = [(batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item()) for i in range(args.train_batch_size)]
+                target_size = (args.resolution, args.resolution)
+                add_time_ids = torch.cat([torch.tensor(original_size[i] + crop_top_left[i] + target_size) for i in range(args.train_batch_size)]).to(accelerator.device, dtype=weight_dtype)
                added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
                encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)

                # Predict the noise residual
-                model_pred = unet(
-                    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-                ).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
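SDXL's added conditioning packs six numbers per sample: original (height, width), crop (top, left), and target (height, width). A standalone sketch of the assembly above with illustrative values; note that `torch.cat` over the per-sample 1-D tensors yields a flat `(batch * 6,)` tensor, whereas `torch.stack` would give `(batch, 6)`:

```python
import torch

original_size = [(1024, 1280), (768, 1024)]  # (height, width) per sample
crop_top_left = [(0, 128), (0, 0)]           # (top, left) per sample
target_size = (1024, 1024)

per_sample = [torch.tensor(o + c + target_size) for o, c in zip(original_size, crop_top_left)]
add_time_ids = torch.cat(per_sample)  # shape (batch * 6,), as in the diff
assert add_time_ids.shape == (12,)
assert torch.stack(per_sample).shape == (2, 6)
```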
@@ -929,19 +895,15 @@ def main():
                    progress_bar.update(1)
                    global_step += 1
                    if global_step % args.save_steps == 0:
-                        weight_name = (
-                            f"learned_embeds-steps-{global_step}.bin"
-                            if args.no_safe_serialization
-                            else f"learned_embeds-steps-{global_step}.safetensors"
-                        )
+                        weight_name = (f"learned_embeds-steps-{global_step}.safetensors")
                        save_path = os.path.join(args.output_dir, weight_name)
                        save_progress(
                            text_encoder_1,
                            placeholder_token_ids,
                            accelerator,
                            args,
                            save_path,
-                            safe_serialization=not args.no_safe_serialization,
+                            safe_serialization=True,
                        )

                    if accelerator.is_main_process:
@@ -972,16 +934,7 @@ def main():

                        if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                            images = log_validation(
-                                text_encoder_1,
-                                text_encoder_2,
-                                tokenizer_1,
-                                tokenizer_2,
-                                unet,
-                                vae,
-                                args,
-                                accelerator,
-                                weight_dtype,
-                                epoch,
+                                text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
                            )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]}
@@ -993,6 +946,10 @@ def main():
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
+        images = log_validation(
+            text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
+        )
+
        if args.push_to_hub and not args.save_as_full_pipeline:
            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
@@ -1002,23 +959,23 @@ def main():
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder_1),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+            text_encoder_2=text_encoder_2,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer_1,
            tokenizer_2=tokenizer_2,
        )
        pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
-        weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
+        weight_name = "learned_embeds.safetensors"
        save_path = os.path.join(args.output_dir, weight_name)
        save_progress(
            text_encoder_1,
            placeholder_token_ids,
            accelerator,
            args,
            save_path,
-            safe_serialization=not args.no_safe_serialization,
+            safe_serialization=True,
        )

        if args.push_to_hub:
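Since the embeddings are now always written as safetensors, they can be loaded back with `safetensors.torch.load_file` and registered via `load_textual_inversion`. A hedged inference sketch, assuming the placeholder token was `<cat-toy>` and the embedding was saved for the first text encoder (paths and the base-model id are illustrative):

```python
import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file

state_dict = load_file("output_dir/learned_embeds.safetensors")

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Register the learned embedding with the first text encoder / tokenizer pair.
pipe.load_textual_inversion(
    state_dict, token="<cat-toy>", text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
image = pipe("A <cat-toy> backpack").images[0]
image.save("cat-toy-backpack.png")
```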
@@ -1035,6 +992,9 @@ def main():
                ignore_patterns=["step_*", "epoch_*"],
            )

+        for i in range(len(images)):
+            images[i].save(f"cat-backpack_sdxl_test_{i}.png")
+
    accelerator.end_training()

