From 1a44ebbacac5bfc26ca4a1237b6b1da7006c96cd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 2 Jan 2024 03:28:20 -0500 Subject: [PATCH 01/16] enable stable-xl textual inversion --- .../textual_inversion/textual_inversion.py | 92 +++++++++++++++++-- 1 file changed, 83 insertions(+), 9 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 50bcc992064d..af1232c0c1fd 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -40,7 +40,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer import diffusers from diffusers import ( @@ -767,6 +767,53 @@ def main(): text_encoder, optimizer, train_dataloader, lr_scheduler ) + has_added_cond_kwargs = True if "stable-diffusion-xl" in args.pretrained_model_name_or_path else False + if has_added_cond_kwargs: + tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder_2.resize_token_embeddings(len(tokenizer_2)) + # Freeze all parameters except for the token embeddings in text encoder + text_encoder_2.text_model.encoder.requires_grad_(False) + text_encoder_2.text_model.final_layer_norm.requires_grad_(False) + text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False) + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + + train_dataset_2 = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer=tokenizer_2, + size=args.resolution, + placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))), + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + train_dataloader_2 = torch.utils.data.DataLoader( + train_dataset_2, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + optimizer_2 = torch.optim.AdamW( + text_encoder_2.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + lr_scheduler_2 = get_scheduler( + args.lr_scheduler, + optimizer=optimizer_2, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + ) + text_encoder_2.train() + text_encoder_2, optimizer_2, train_dataloader_2, lr_scheduler_2 = accelerator.prepare( + text_encoder_2, optimizer_2, train_dataloader_2, lr_scheduler_2 + ) + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. 
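An illustrative aside on what the block added above is building toward: patch 01 loads SDXL's second tokenizer and text encoder, optimizes the token embeddings of both, and the training step later concatenates the two encoders' penultimate hidden states and passes the pooled projection plus the size/crop "time ids" to the UNet through `added_cond_kwargs`. A minimal sketch of that wiring, under the assumption of an invented helper name `build_sdxl_conditioning` and a single square image with no cropping (neither appears in the patch itself):

import torch


def build_sdxl_conditioning(text_encoder_1, text_encoder_2, input_ids_1, input_ids_2,
                            resolution=1024, dtype=torch.float32):
    # SDXL conditions on the penultimate hidden states of both CLIP text encoders.
    out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
    out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)
    hidden_1 = out_1.hidden_states[-2].to(dtype)  # e.g. (batch, seq_len, 768) for the first encoder
    hidden_2 = out_2.hidden_states[-2].to(dtype)  # e.g. (batch, seq_len, 1280) for the second encoder

    # The UNet cross-attention sees both streams concatenated along the feature axis.
    encoder_hidden_states = torch.cat([hidden_1, hidden_2], dim=-1)

    # Micro-conditioning: (original_size, crop_top_left, target_size); here a single
    # square, uncropped image is assumed, matching what the patch hard-codes.
    add_time_ids = torch.tensor(
        [[resolution, resolution, 0, 0, resolution, resolution]], dtype=dtype
    )
    # `text_embeds` is the projected pooled output of the second encoder, i.e. its
    # first return value when it is a CLIPTextModelWithProjection.
    added_cond_kwargs = {"text_embeds": out_2[0], "time_ids": add_time_ids}
    return encoder_hidden_states, added_cond_kwargs

Note that the training loop also snapshots the original embedding matrix and, after every optimizer step, restores all rows except the placeholder rows, so only the newly added vectors are actually learned.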
weight_dtype = torch.float32 @@ -844,6 +891,7 @@ def main(): for epoch in range(first_epoch, args.num_train_epochs): text_encoder.train() + text_encoder_2.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): # Convert images to latent space @@ -863,9 +911,21 @@ def main(): # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) + added_cond_kwargs = None + if has_added_cond_kwargs: + encoder_hidden_states = text_encoder(batch["input_ids"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) + encoder_output_2 = text_encoder_2(train_dataset_2[step]["input_ids"].reshape(batch["input_ids"].shape[0], -1), output_hidden_states=True) + encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) + sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1)) + original_size = (sample_size, sample_size) + add_time_ids = torch.tensor([list(original_size + (0, 0) + original_size)], dtype=weight_dtype) + added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) + else: + encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -883,6 +943,10 @@ def main(): lr_scheduler.step() optimizer.zero_grad() + optimizer_2.step() + lr_scheduler_2.step() + optimizer_2.zero_grad() + # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False @@ -959,13 +1023,23 @@ def main(): else: save_full_model = args.save_as_full_pipeline if save_full_model: - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, - tokenizer=tokenizer, - ) + if "xl" in args.pretrained_model_name_or_path: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + vae=vae, + unet=unet, + tokenizer=tokenizer, + ) + else: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + ) pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" From b9518fc6d59c82a2a69c584378eabbe4049f75d8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 2 Jan 2024 03:51:06 -0500 Subject: [PATCH 02/16] check if optimizer_2 exists --- examples/textual_inversion/textual_inversion.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index af1232c0c1fd..70262d48c7b3 100644 --- 
a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -943,9 +943,10 @@ def main(): lr_scheduler.step() optimizer.zero_grad() - optimizer_2.step() - lr_scheduler_2.step() - optimizer_2.zero_grad() + if has_added_cond_kwargs: + optimizer_2.step() + lr_scheduler_2.step() + optimizer_2.zero_grad() # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) @@ -1023,7 +1024,7 @@ def main(): else: save_full_model = args.save_as_full_pipeline if save_full_model: - if "xl" in args.pretrained_model_name_or_path: + if has_added_cond_kwargs: pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder), From d74f76524a9c2da4d97dd1afa25673f61fafc8c7 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 3 Jan 2024 04:46:24 -0500 Subject: [PATCH 03/16] check text_encoder_2 before using --- examples/textual_inversion/textual_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 70262d48c7b3..e09b2fa6a6dc 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -891,7 +891,8 @@ def main(): for epoch in range(first_epoch, args.num_train_epochs): text_encoder.train() - text_encoder_2.train() + if has_added_cond_kwargs: + text_encoder_2.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): # Convert images to latent space From 3c6cafda506a6c13c88cd7be4b9ecf5716bab857 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 5 Jan 2024 06:14:31 -0500 Subject: [PATCH 04/16] add textual inversion for sdxl in a single file --- .../test_textual_inversion_sdxl.py | 152 +++ .../textual_inversion/textual_inversion.py | 96 +- .../textual_inversion_sdxl.py | 1024 +++++++++++++++++ 3 files changed, 1186 insertions(+), 86 deletions(-) create mode 100644 examples/textual_inversion/test_textual_inversion_sdxl.py create mode 100644 examples/textual_inversion/textual_inversion_sdxl.py diff --git a/examples/textual_inversion/test_textual_inversion_sdxl.py b/examples/textual_inversion/test_textual_inversion_sdxl.py new file mode 100644 index 000000000000..5c8f949feed3 --- /dev/null +++ b/examples/textual_inversion/test_textual_inversion_sdxl.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
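The tests below train against a tiny SDXL checkpoint and assert that `learned_embeds.safetensors` is written. For context, an embedding file in that format is typically consumed at inference time roughly as follows; the model id, paths and the `<my-concept>` token string are placeholders, and this assumes diffusers' standard `load_textual_inversion` loader rather than anything defined in this PR:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# This version of the training script saves embeddings for the first text encoder,
# so attach them to pipe.text_encoder / pipe.tokenizer.
pipe.load_textual_inversion(
    "text-inversion-model/learned_embeds.safetensors",
    token="<my-concept>",
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
)

image = pipe("a photo of <my-concept>", num_inference_steps=25).images[0]
image.save("concept.png")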
+ +import logging +import os +import sys +import tempfile + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class TextualInversionSdxl(ExamplesTestsAccelerate): + def test_textual_inversion_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors"))) + + def test_textual_inversion_sdxl_checkpointing(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 3 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + test_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-3"}, + ) + + def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + """.split() + + run_command(self._launch_args + test_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-1", "checkpoint-2"}, + ) + + resume_run_args = f""" + examples/textual_inversion/textual_inversion_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + --resume_from_checkpoint=checkpoint-2 + --checkpoints_total_limit=2 + 
""".split() + + run_command(self._launch_args + resume_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-3"}, + ) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index e09b2fa6a6dc..319e94895aae 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -40,7 +40,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer import diffusers from diffusers import ( @@ -767,53 +767,6 @@ def main(): text_encoder, optimizer, train_dataloader, lr_scheduler ) - has_added_cond_kwargs = True if "stable-diffusion-xl" in args.pretrained_model_name_or_path else False - if has_added_cond_kwargs: - tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision - ) - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder_2.resize_token_embeddings(len(tokenizer_2)) - # Freeze all parameters except for the token embeddings in text encoder - text_encoder_2.text_model.encoder.requires_grad_(False) - text_encoder_2.text_model.final_layer_norm.requires_grad_(False) - text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False) - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - - train_dataset_2 = TextualInversionDataset( - data_root=args.train_data_dir, - tokenizer=tokenizer_2, - size=args.resolution, - placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))), - repeats=args.repeats, - learnable_property=args.learnable_property, - center_crop=args.center_crop, - set="train", - ) - train_dataloader_2 = torch.utils.data.DataLoader( - train_dataset_2, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers - ) - optimizer_2 = torch.optim.AdamW( - text_encoder_2.get_input_embeddings().parameters(), # only optimize the embeddings - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - lr_scheduler_2 = get_scheduler( - args.lr_scheduler, - optimizer=optimizer_2, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - ) - text_encoder_2.train() - text_encoder_2, optimizer_2, train_dataloader_2, lr_scheduler_2 = accelerator.prepare( - text_encoder_2, optimizer_2, train_dataloader_2, lr_scheduler_2 - ) - # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. 
weight_dtype = torch.float32 @@ -891,8 +844,6 @@ def main(): for epoch in range(first_epoch, args.num_train_epochs): text_encoder.train() - if has_added_cond_kwargs: - text_encoder_2.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): # Convert images to latent space @@ -912,21 +863,9 @@ def main(): # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) - added_cond_kwargs = None - if has_added_cond_kwargs: - encoder_hidden_states = text_encoder(batch["input_ids"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) - encoder_output_2 = text_encoder_2(train_dataset_2[step]["input_ids"].reshape(batch["input_ids"].shape[0], -1), output_hidden_states=True) - encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) - sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1)) - original_size = (sample_size, sample_size) - add_time_ids = torch.tensor([list(original_size + (0, 0) + original_size)], dtype=weight_dtype) - added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) - else: - encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -944,11 +883,6 @@ def main(): lr_scheduler.step() optimizer.zero_grad() - if has_added_cond_kwargs: - optimizer_2.step() - lr_scheduler_2.step() - optimizer_2.zero_grad() - # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False @@ -1025,23 +959,13 @@ def main(): else: save_full_model = args.save_as_full_pipeline if save_full_model: - if has_added_cond_kwargs: - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - text_encoder_2=accelerator.unwrap_model(text_encoder_2), - vae=vae, - unet=unet, - tokenizer=tokenizer, - ) - else: - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, - tokenizer=tokenizer, - ) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + ) pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" @@ -1073,4 +997,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py new file mode 100644 index 000000000000..f725a59848e2 --- /dev/null +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -0,0 +1,1024 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. 
team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +import warnings +from pathlib import Path + +import numpy as np +import PIL +import safetensors +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- textual_inversion +inference: true +--- + """ + model_card = f""" +# Textual inversion text2image fine-tuning - {repo_id} +These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch): + logger.info( + f"Running validation... 
\n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=tokenizer_1, + tokenizer_2=tokenizer_2, + unet=unet, + vae=vae, + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + return images + + +def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True): + logger.info("Saving embeddings") + learned_embeds = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] + ) + learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + + if safe_serialization: + safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"}) + else: + torch.save(learned_embeds_dict, save_path) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save learned_embeds.bin every X updates steps.", + ) + parser.add_argument( + "--save_as_full_pipeline", + action="store_true", + help="Save the complete stable diffusion pipeline.", + ) + parser.add_argument( + "--num_vectors", + type=int, + default=1, + help="How many textual inversion vectors shall be used to learn the concept.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." 
+ ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." + ) + parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=5000, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+ ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help=( + "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--no_safe_serialization", + action="store_true", + help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer_1, + tokenizer_2, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + self.data_root = data_root + self.tokenizer_1 = tokenizer_1 + self.tokenizer_2 = tokenizer_2 + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def 
__len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + example["input_ids_1"] = self.tokenizer_1( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer_1.model_max_length, + return_tensors="pt", + ).input_ids[0] + + example["input_ids_2"] = self.tokenizer_2( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer_2.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load tokenizer + tokenizer_1 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_1 = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + + # Add the placeholder token in tokenizer_1 + placeholder_tokens = [args.placeholder_token] + + if args.num_vectors < 1: + raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}") + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, args.num_vectors): + additional_tokens.append(f"{args.placeholder_token}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = tokenizer_1.add_tokens(placeholder_tokens) + if num_added_tokens != args.num_vectors: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." 
+ ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder_1.resize_token_embeddings(len(tokenizer_1)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder_1.get_input_embeddings().weight.data + with torch.no_grad(): + for token_id in placeholder_token_ids: + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # Freeze vae and unet + vae.requires_grad_(False) + unet.requires_grad_(False) + # Freeze all parameters except for the token embeddings in text encoder + text_encoder_1.text_model.encoder.requires_grad_(False) + text_encoder_1.text_model.final_layer_norm.requires_grad_(False) + text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False) + text_encoder_2.text_model.encoder.requires_grad_(False) + text_encoder_2.text_model.final_layer_norm.requires_grad_(False) + text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False) + + if args.gradient_checkpointing: + # Keep unet in train mode if we are using gradient checkpointing to save memory. + # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. + unet.train() + text_encoder_1.gradient_checkpointing_enable() + text_encoder_2.gradient_checkpointing_enable() + unet.enable_gradient_checkpointing() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + optimizer_1 = torch.optim.AdamW( + text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + size=args.resolution, + placeholder_token=(" ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))), + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + if args.validation_epochs is not None: + warnings.warn( + f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." + " Deprecated validation_epochs in favor of `validation_steps`" + f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", + FutureWarning, + stacklevel=2, + ) + args.validation_steps = args.validation_epochs * len(train_dataset) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler_1 = get_scheduler( + args.lr_scheduler, + optimizer=optimizer_1, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + ) + + text_encoder_1.train() + # Prepare everything with our `accelerator`. + text_encoder_1, optimizer_1, train_dataloader, lr_scheduler_1 = accelerator.prepare( + text_encoder_1, optimizer_1, train_dataloader, lr_scheduler_1 + ) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and unet and text_encoder_2 to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + # keep original embeddings as reference + orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone() + + for epoch in range(first_epoch, args.num_train_epochs): + text_encoder_1.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder_1): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) + encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True) + encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) + sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1)) + original_size = (sample_size, sample_size) + add_time_ids = torch.tensor([list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device) + added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} + encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + optimizer_1.step() + lr_scheduler_1.step() + optimizer_1.zero_grad() + + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool) + index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False + + with torch.no_grad(): + accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + images = [] + progress_bar.update(1) + global_step += 1 + if global_step % args.save_steps == 0: + weight_name = ( + f"learned_embeds-steps-{global_step}.bin" + if args.no_safe_serialization + else f"learned_embeds-steps-{global_step}.safetensors" + ) + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder_1, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not 
args.no_safe_serialization, + ) + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + images = log_validation( + text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.push_to_hub and not args.save_as_full_pipeline: + logger.warn("Enabling full model saving because --push_to_hub=True was specified.") + save_full_model = True + else: + save_full_model = args.save_as_full_pipeline + if save_full_model: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + vae=vae, + unet=unet, + tokenizer=tokenizer_1, + tokenizer_2=tokenizer_2, + ) + pipeline.save_pretrained(args.output_dir) + # Save the newly trained embeddings + weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder_1, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() + From 21ab27ec7d54de0dbdb45a15ef1f7a4cd494f135 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 5 Jan 2024 06:27:03 -0500 Subject: [PATCH 05/16] fix style --- examples/textual_inversion/textual_inversion.py | 2 +- examples/textual_inversion/textual_inversion_sdxl.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py 
index 319e94895aae..50bcc992064d 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -997,4 +997,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index f725a59848e2..c51e490aafc7 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -48,7 +48,6 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, - StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler From a2eda7b8d4c7286aee2ba0264b632c4fa44a3a78 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 5 Jan 2024 06:30:18 -0500 Subject: [PATCH 06/16] fix example style --- .../textual_inversion_sdxl.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index c51e490aafc7..9ecf14ee1d06 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -111,7 +111,9 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None) f.write(yaml + model_card) -def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch): +def log_validation( + text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch +): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
@@ -644,7 +646,6 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - # Add the placeholder token in tokenizer_1 placeholder_tokens = [args.placeholder_token] @@ -875,17 +876,27 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) - encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True) + encoder_hidden_states_1 = ( + text_encoder_1(batch["input_ids_1"], output_hidden_states=True) + .hidden_states[-2] + .to(dtype=weight_dtype) + ) + encoder_output_2 = text_encoder_2( + batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True + ) encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1)) original_size = (sample_size, sample_size) - add_time_ids = torch.tensor([list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device) + add_time_ids = torch.tensor( + [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device + ) added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -961,7 +972,16 @@ def main(): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch + text_encoder_1, + text_encoder_2, + tokenizer_1, + tokenizer_2, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]} @@ -1020,4 +1040,3 @@ def main(): if __name__ == "__main__": main() - From 322ef19cfd8f34f5992b8d80965ace61f83f5852 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Jan 2024 04:45:41 -0500 Subject: [PATCH 07/16] reset for error changes --- .../textual_inversion_sdxl.py | 122 ++++++------------ 1 file changed, 41 insertions(+), 81 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 9ecf14ee1d06..a96bc90b3b1b 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -48,6 +48,7 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler @@ -111,18 +112,15 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None) f.write(yaml + model_card) -def log_validation( - text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch -): +def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, 
unet, vae, args, accelerator, weight_dtype, epoch): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder_1), - text_encoder_2=accelerator.unwrap_model(text_encoder_2), + text_encoder_2=text_encoder_2, tokenizer=tokenizer_1, tokenizer_2=tokenizer_2, unet=unet, @@ -361,7 +359,7 @@ def parse_args(): parser.add_argument( "--validation_prompt", type=str, - default=None, + default="A backpack", help="A prompt that is used during validation to verify that the model is learning.", ) parser.add_argument( @@ -380,16 +378,6 @@ def parse_args(): " and logging the images." ), ) - parser.add_argument( - "--validation_epochs", - type=int, - default=None, - help=( - "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`" - " and logging the images." - ), - ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( "--checkpointing_steps", @@ -418,11 +406,6 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) - parser.add_argument( - "--no_safe_serialization", - action="store_true", - help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", - ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -529,6 +512,7 @@ def __init__( self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) def __len__(self): return self._length @@ -543,6 +527,18 @@ def __getitem__(self, i): placeholder_string = self.placeholder_token text = random.choice(self.templates).format(placeholder_string) + example["original_size"] = (image.height, image.width) + + if self.center_crop: + y1 = max(0, int(round((image.height - self.size) / 2.0))) + x1 = max(0, int(round((image.width - self.size) / 2.0))) + image = self.crop(image) + else: + y1, x1, h, w = self.crop.get_params(image, (self.size, self.size)) + image = transforms.functional.crop(image, y1, x1, h, w) + + example["crop_top_left"] = (y1, x1) + example["input_ids_1"] = self.tokenizer_1( text, padding="max_length", @@ -564,13 +560,7 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) + (h, w,) = (img.shape[0], img.shape[1],) img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] image = Image.fromarray(img) @@ -646,6 +636,7 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) + # Add the placeholder token in tokenizer_1 placeholder_tokens = [args.placeholder_token] @@ -686,21 +677,14 @@ def main(): # Freeze vae and unet vae.requires_grad_(False) unet.requires_grad_(False) + text_encoder_2.requires_grad_(False) # Freeze all parameters except for the token embeddings in text encoder text_encoder_1.text_model.encoder.requires_grad_(False) 
text_encoder_1.text_model.final_layer_norm.requires_grad_(False) text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder_2.text_model.encoder.requires_grad_(False) - text_encoder_2.text_model.final_layer_norm.requires_grad_(False) - text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False) if args.gradient_checkpointing: - # Keep unet in train mode if we are using gradient checkpointing to save memory. - # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. - unet.train() text_encoder_1.gradient_checkpointing_enable() - text_encoder_2.gradient_checkpointing_enable() - unet.enable_gradient_checkpointing() if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -749,15 +733,6 @@ def main(): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) - if args.validation_epochs is not None: - warnings.warn( - f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." - " Deprecated validation_epochs in favor of `validation_steps`" - f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", - FutureWarning, - stacklevel=2, - ) - args.validation_steps = args.validation_epochs * len(train_dataset) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -791,7 +766,7 @@ def main(): # Move vae and unet and text_encoder_2 to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -876,27 +851,18 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states_1 = ( - text_encoder_1(batch["input_ids_1"], output_hidden_states=True) - .hidden_states[-2] - .to(dtype=weight_dtype) - ) - encoder_output_2 = text_encoder_2( - batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True - ) + encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) + encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True) encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) - sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1)) - original_size = (sample_size, sample_size) - add_time_ids = torch.tensor( - [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device - ) + original_size = [(batch["original_size"][0][i].item(), batch["original_size"][1][i].item()) for i in range(args.train_batch_size)] + crop_top_left = [(batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item()) for i in range(args.train_batch_size)] + target_size = (args.resolution, args.resolution) + add_time_ids = torch.cat([torch.tensor(original_size[i] + crop_top_left[i] + target_size) for i in range(args.train_batch_size)]).to(accelerator.device, dtype=weight_dtype) added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1) # Predict the noise residual - model_pred = unet( - noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -929,11 +895,7 @@ def main(): progress_bar.update(1) global_step += 1 if global_step % args.save_steps == 0: - weight_name = ( - f"learned_embeds-steps-{global_step}.bin" - if args.no_safe_serialization - else f"learned_embeds-steps-{global_step}.safetensors" - ) + weight_name = (f"learned_embeds-steps-{global_step}.safetensors") save_path = os.path.join(args.output_dir, weight_name) save_progress( text_encoder_1, @@ -941,7 +903,7 @@ def main(): accelerator, args, save_path, - safe_serialization=not args.no_safe_serialization, + safe_serialization=True, ) if accelerator.is_main_process: @@ -972,16 +934,7 @@ def main(): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder_1, - text_encoder_2, - tokenizer_1, - tokenizer_2, - unet, - vae, - args, - accelerator, - weight_dtype, - epoch, + text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]} @@ -993,6 +946,10 @@ def main(): # Create the pipeline using the trained modules and save it. 
accelerator.wait_for_everyone() if accelerator.is_main_process: + images = log_validation( + text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch + ) + if args.push_to_hub and not args.save_as_full_pipeline: logger.warn("Enabling full model saving because --push_to_hub=True was specified.") save_full_model = True @@ -1002,7 +959,7 @@ def main(): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder_1), - text_encoder_2=accelerator.unwrap_model(text_encoder_2), + text_encoder_2=text_encoder_2, vae=vae, unet=unet, tokenizer=tokenizer_1, @@ -1010,7 +967,7 @@ def main(): ) pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings - weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" + weight_name = "learned_embeds.safetensors" save_path = os.path.join(args.output_dir, weight_name) save_progress( text_encoder_1, @@ -1018,7 +975,7 @@ def main(): accelerator, args, save_path, - safe_serialization=not args.no_safe_serialization, + safe_serialization=True, ) if args.push_to_hub: @@ -1035,6 +992,9 @@ def main(): ignore_patterns=["step_*", "epoch_*"], ) + for i in range(len(images)): + images[i].save(f"cat-backpack_sdxl_test_{i}.png") + accelerator.end_training() From e796492fedbfb410cff9f5da5aefa23eee32f461 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Jan 2024 04:52:04 -0500 Subject: [PATCH 08/16] add readme for sdxl --- examples/textual_inversion/README_sdxl.md | 27 +++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/textual_inversion/README_sdxl.md diff --git a/examples/textual_inversion/README_sdxl.md b/examples/textual_inversion/README_sdxl.md new file mode 100644 index 000000000000..385311efb216 --- /dev/null +++ b/examples/textual_inversion/README_sdxl.md @@ -0,0 +1,27 @@ +## Textual Inversion fine-tuning example for SDXL + +The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command: +``` +export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" +export DATA_DIR="./cat" + +accelerate launch textual_inversion_sdxl.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATA_DIR \ + --learnable_property="object" \ + --placeholder_token="" \ + --initializer_token="toy" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=500 \ + --learning_rate=5.0e-04 \ + --scale_lr \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --save_as_full_pipeline \ + --output_dir="./textual_inversion_cat_sdxl" +``` + +We only enbled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem. 
\ No newline at end of file From 1aad9db0774773151a31aad0bbc23ea9fbbe2323 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Jan 2024 04:54:25 -0500 Subject: [PATCH 09/16] fix style --- examples/textual_inversion/textual_inversion_sdxl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index a96bc90b3b1b..2b3a0a5bcca7 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -19,7 +19,6 @@ import os import random import shutil -import warnings from pathlib import Path import numpy as np @@ -48,7 +47,6 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, - StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler From 73de30a80a6da130819b791953e0e75b4865d4af Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Jan 2024 05:21:33 -0500 Subject: [PATCH 10/16] disable autocast as it will cause cast error when weight_dtype=bf16 --- examples/textual_inversion/textual_inversion_sdxl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 2b3a0a5bcca7..bf039616dcdc 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -136,8 +136,7 @@ def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, une generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: From 1576fdf19d29d702260ebb69acefa15aa641dbd4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Jan 2024 05:38:10 -0500 Subject: [PATCH 11/16] fix spelling error --- examples/textual_inversion/README_sdxl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/textual_inversion/README_sdxl.md b/examples/textual_inversion/README_sdxl.md index 385311efb216..1cd4d17fe191 100644 --- a/examples/textual_inversion/README_sdxl.md +++ b/examples/textual_inversion/README_sdxl.md @@ -24,4 +24,4 @@ accelerate launch textual_inversion_sdxl.py \ --output_dir="./textual_inversion_cat_sdxl" ``` -We only enbled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem. \ No newline at end of file +We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem. 
\ No newline at end of file From c1619918d377cff10614f90b804a7c79bf17bdbe Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 9 Jan 2024 08:23:51 -0500 Subject: [PATCH 12/16] fix style and readme and 8bit optimizer --- examples/textual_inversion/README.md | 2 ++ examples/textual_inversion/README_sdxl.md | 5 ++- .../textual_inversion_sdxl.py | 36 +++++++++++++------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 0a2723f0982f..3831e7cd1165 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -60,6 +60,8 @@ Now we can launch the training using: **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** +**___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___** + ```bash export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATA_DIR="./cat" diff --git a/examples/textual_inversion/README_sdxl.md b/examples/textual_inversion/README_sdxl.md index 1cd4d17fe191..2c1c80f7f286 100644 --- a/examples/textual_inversion/README_sdxl.md +++ b/examples/textual_inversion/README_sdxl.md @@ -1,6 +1,5 @@ ## Textual Inversion fine-tuning example for SDXL -The `textual_inversion.py` do not support training stable-diffusion-XL as it has two text encoders, you can training SDXL by the following command: ``` export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" export DATA_DIR="./cat" @@ -12,7 +11,7 @@ accelerate launch textual_inversion_sdxl.py \ --placeholder_token="" \ --initializer_token="toy" \ --mixed_precision="bf16" \ - --resolution=1024 \ + --resolution=768 \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --max_train_steps=500 \ @@ -24,4 +23,4 @@ accelerate launch textual_inversion_sdxl.py \ --output_dir="./textual_inversion_cat_sdxl" ``` -We only enabled training the first text encoder because of the precision issue, we will enable training the second text encoder once we fixed the problem. \ No newline at end of file +For now, only training of the first text encoder is supported. \ No newline at end of file diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index bf039616dcdc..261c378e9f61 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -304,6 +304,9 @@ def parse_args(): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -707,7 +710,19 @@ def main(): ) # Initialize the optimizer - optimizer_1 = torch.optim.AdamW( + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), @@ -715,13 +730,14 @@ def main(): eps=args.adam_epsilon, ) + placeholder_token = " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids)) # Dataset and DataLoaders creation: train_dataset = TextualInversionDataset( data_root=args.train_data_dir, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, size=args.resolution, - placeholder_token=(" ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))), + placeholder_token=placeholder_token, repeats=args.repeats, learnable_property=args.learnable_property, center_crop=args.center_crop, @@ -738,9 +754,9 @@ def main(): args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - lr_scheduler_1 = get_scheduler( + lr_scheduler = get_scheduler( args.lr_scheduler, - optimizer=optimizer_1, + optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, @@ -748,8 +764,8 @@ def main(): text_encoder_1.train() # Prepare everything with our `accelerator`. - text_encoder_1, optimizer_1, train_dataloader, lr_scheduler_1 = accelerator.prepare( - text_encoder_1, optimizer_1, train_dataloader, lr_scheduler_1 + text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder_1, optimizer, train_dataloader, lr_scheduler ) # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision @@ -873,9 +889,9 @@ def main(): accelerator.backward(loss) - optimizer_1.step() - lr_scheduler_1.step() - optimizer_1.zero_grad() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool) @@ -934,7 +950,7 @@ def main(): text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch ) - logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]} + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) From 5d439df0ba697a3756b19db953f256443c5ce16a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 9 Jan 2024 09:08:18 -0500 Subject: [PATCH 13/16] add README_sdxl.md link --- examples/textual_inversion/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 3831e7cd1165..9e3a622943a1 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -60,7 +60,7 @@ Now we can launch the training using: **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** -**___Note: Please follow the README_sdxl.md if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___** +**___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___** ```bash 
export MODEL_NAME="runwayml/stable-diffusion-v1-5" From 0ed88f03fefc43ad5fc8da3028975fb5c139c73c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 9 Jan 2024 10:13:41 -0500 Subject: [PATCH 14/16] add tracker key on log_validation --- .../textual_inversion/textual_inversion_sdxl.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 261c378e9f61..acdfae7926f5 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -110,7 +110,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None) f.write(yaml + model_card) -def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch): +def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch, is_final_validation=False): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." @@ -139,14 +139,15 @@ def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, une image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] images.append(image) + tracker_key = "test" if is_final_validation else "validation" for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { - "validation": [ + tracker_key: [ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } @@ -959,9 +960,10 @@ def main(): # Create the pipeline using the trained modules and save it. 
accelerator.wait_for_everyone() if accelerator.is_main_process: - images = log_validation( - text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch - ) + if args.validation_prompt: + images = log_validation( + text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch, is_final_validation=True + ) if args.push_to_hub and not args.save_as_full_pipeline: logger.warn("Enabling full model saving because --push_to_hub=True was specified.") @@ -1005,9 +1007,6 @@ def main(): ignore_patterns=["step_*", "epoch_*"], ) - for i in range(len(images)): - images[i].save(f"cat-backpack_sdxl_test_{i}.png") - accelerator.end_training() From 49fa6ab1baf7266eb3fa4678c94ef56fe0b7e8e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jan 2024 13:17:04 +0530 Subject: [PATCH 15/16] run style --- .../textual_inversion_sdxl.py | 81 ++++++++++++++++--- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index acdfae7926f5..b36effb04c51 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -110,7 +110,19 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None) f.write(yaml + model_card) -def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch, is_final_validation=False): +def log_validation( + text_encoder_1, + text_encoder_2, + tokenizer_1, + tokenizer_2, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, + is_final_validation=False, +): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
@@ -561,7 +573,13 @@ def __getitem__(self, i): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - (h, w,) = (img.shape[0], img.shape[1],) + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] image = Image.fromarray(img) @@ -637,7 +655,6 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - # Add the placeholder token in tokenizer_1 placeholder_tokens = [args.placeholder_token] @@ -865,18 +882,37 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype) - encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True) + encoder_hidden_states_1 = ( + text_encoder_1(batch["input_ids_1"], output_hidden_states=True) + .hidden_states[-2] + .to(dtype=weight_dtype) + ) + encoder_output_2 = text_encoder_2( + batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True + ) encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype) - original_size = [(batch["original_size"][0][i].item(), batch["original_size"][1][i].item()) for i in range(args.train_batch_size)] - crop_top_left = [(batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item()) for i in range(args.train_batch_size)] + original_size = [ + (batch["original_size"][0][i].item(), batch["original_size"][1][i].item()) + for i in range(args.train_batch_size) + ] + crop_top_left = [ + (batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item()) + for i in range(args.train_batch_size) + ] target_size = (args.resolution, args.resolution) - add_time_ids = torch.cat([torch.tensor(original_size[i] + crop_top_left[i] + target_size) for i in range(args.train_batch_size)]).to(accelerator.device, dtype=weight_dtype) + add_time_ids = torch.cat( + [ + torch.tensor(original_size[i] + crop_top_left[i] + target_size) + for i in range(args.train_batch_size) + ] + ).to(accelerator.device, dtype=weight_dtype) added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids} encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -909,7 +945,7 @@ def main(): progress_bar.update(1) global_step += 1 if global_step % args.save_steps == 0: - weight_name = (f"learned_embeds-steps-{global_step}.safetensors") + weight_name = f"learned_embeds-steps-{global_step}.safetensors" save_path = os.path.join(args.output_dir, weight_name) save_progress( text_encoder_1, @@ -948,7 +984,16 @@ def main(): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch + text_encoder_1, + text_encoder_2, + tokenizer_1, + tokenizer_2, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, ) logs = {"loss": 
loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -962,8 +1007,18 @@ def main(): if accelerator.is_main_process: if args.validation_prompt: images = log_validation( - text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch, is_final_validation=True - ) + text_encoder_1, + text_encoder_2, + tokenizer_1, + tokenizer_2, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, + is_final_validation=True, + ) if args.push_to_hub and not args.save_as_full_pipeline: logger.warn("Enabling full model saving because --push_to_hub=True was specified.") From d6bb1d9c829fbfa71078286d8eea29850e7b6e6b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 9 Jan 2024 11:25:53 -0500 Subject: [PATCH 16/16] rm the second center crop --- examples/textual_inversion/textual_inversion_sdxl.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index b36effb04c51..6eba281e7165 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -571,17 +571,6 @@ def __getitem__(self, i): # default to score-sde preprocessing img = np.array(image).astype(np.uint8) - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] - image = Image.fromarray(img) image = image.resize((self.size, self.size), resample=self.interpolation)
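
A minimal inference sketch for the result of this series, assuming the `--save_as_full_pipeline` and `--output_dir="./textual_inversion_cat_sdxl"` settings from the README_sdxl.md command introduced in PATCH 08/16, and an angle-bracket placeholder token such as `<cat-toy>` (the exact token value is not visible in the README text, so it is an assumption here). Because only the first text encoder carries the learned embedding while `text_encoder_2` stays frozen, the saved pipeline can be loaded and prompted as-is:

```python
# Sketch only: load the full SDXL pipeline saved by textual_inversion_sdxl.py
# and generate an image with the learned placeholder token.
import torch
from diffusers import DiffusionPipeline

# Assumed names: "./textual_inversion_cat_sdxl" is args.output_dir from the README
# command; "<cat-toy>" stands in for whatever --placeholder_token was used.
pipe = DiffusionPipeline.from_pretrained(
    "./textual_inversion_cat_sdxl",
    torch_dtype=torch.bfloat16,  # matches --mixed_precision="bf16" used during training
)
pipe.to("cuda")

# The learned embedding lives only in pipe.text_encoder / pipe.tokenizer;
# text_encoder_2 and the UNet are unchanged from the base model.
image = pipe("A <cat-toy> sitting on a park bench", num_inference_steps=25).images[0]
image.save("textual_inversion_sdxl_sample.png")
```

Since `text_encoder_2` is frozen, its pooled output (passed as `text_embeds` in `added_cond_kwargs`) is identical to the base model's, and `text_encoder_1` only updates the embedding row for the placeholder token, so prompts that do not contain the new token should behave exactly like the untuned pipeline.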