From cc4abfdc6ce603d0b81c7e14643e645d20aabc98 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Fri, 12 Jan 2024 18:22:32 +0530
Subject: [PATCH 1/3] fix training resume

---
 examples/dreambooth/train_dreambooth_lora.py | 54 +++++++++++++++++---
 1 file changed, 48 insertions(+), 6 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 2d2629b2fd87..4125126792d1 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -35,7 +35,7 @@
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -54,7 +54,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -892,10 +898,35 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
-        )
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.append(text_encoder_)
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -910,6 +941,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.append(text_encoder)
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:

From 9bb5760f201d852740ff628231eb8f4679fdb922 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Tue, 16 Jan 2024 00:49:11 +0530
Subject: [PATCH 2/3] update

---
 examples/dreambooth/train_dreambooth_lora.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 4125126792d1..80282a6e42e9 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -54,7 +54,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -922,11 +922,7 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.append(text_encoder_)
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            cast_training_params(models, dtype=torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

From 9f76b50541cf27cc0479fe86838af37828a4b15b Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Tue, 16 Jan 2024 00:54:18 +0530
Subject: [PATCH 3/3] update

---
 examples/dreambooth/train_dreambooth_lora.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 80282a6e42e9..3724e3d140d9 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -922,6 +922,8 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.append(text_encoder_)
+
+            # only upcast trainable parameters (LoRA) into fp32
             cast_training_params(models, dtype=torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
@@ -942,11 +944,9 @@
         models = [unet]
         if args.train_text_encoder:
             models.append(text_encoder)
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
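
Note on patches 2/3: the hand-rolled upcasting loop is replaced by cast_training_params from diffusers.training_utils. Judging only from the loop it replaces in these patches, the helper can be read as roughly the sketch below (a simplified approximation for illustration; cast_training_params_sketch is a hypothetical name, not the library's actual implementation):

import torch
from typing import List, Union


def cast_training_params_sketch(models: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    # Accept a single module or a list of modules, mirroring how the training
    # script passes [unet_] or [unet, text_encoder].
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32; the frozen base
            # weights stay in the mixed-precision weight_dtype (e.g. fp16)
            if param.requires_grad:
                param.data = param.to(dtype)

This mirrors the comment introduced in patch 1: the base models stay in weight_dtype, so only the trainable LoRA parameters are upcast, both in the resume path (load_model_hook) and before the optimizer is created.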