
Commit c11de13

[training] fix training resuming problem for fp16 (SD LoRA DreamBooth) (#6554)
* fix training resume
* update
* update
1 parent 357855f · commit c11de13
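
As the comments added in the diff below note, the base UNet and text encoder are kept in `weight_dtype` (fp16 when `--mixed_precision="fp16"` is used), so after loading LoRA weights from a resume checkpoint the trainable parameters have to be upcast back to float32. The following is a rough, illustrative paraphrase of what `cast_training_params` from `diffusers.training_utils` does; it is a sketch for orientation, not the library source:

import torch

def upcast_trainable_params(models, dtype=torch.float32):
    # Paraphrase of diffusers.training_utils.cast_training_params: upcast only the
    # parameters that require gradients (the LoRA weights) and leave the frozen
    # fp16 base weights untouched.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)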

File tree: 1 file changed (+44 −6)

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 44 additions & 6 deletions
@@ -35,7 +35,7 @@
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -54,7 +54,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

@@ -892,10 +898,33 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
-        )
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.append(text_encoder_)
+
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models, dtype=torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
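
For orientation: `load_model_hook` is not called directly by the training loop. It runs when a checkpoint is restored, because `accelerator.load_state()` invokes every hook registered through `register_load_state_pre_hook`, handing it the unwrapped models and the checkpoint directory. A condensed sketch of the resume path that triggers it follows; the `--resume_from_checkpoint` handling and the `checkpoint-<step>` directory layout follow the usual diffusers training scripts and are assumptions here, not part of this diff:

import os

# Sketch: restoring a run. accelerator.load_state() calls load_model_hook above
# with the models and the checkpoint directory, which reloads the LoRA weights
# in PEFT format and upcasts them to fp32.
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint == "latest":
        # Pick the most recent "checkpoint-<step>" folder in the output dir (assumed layout).
        dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
        path = sorted(dirs, key=lambda d: int(d.split("-")[1]))[-1] if dirs else None
    else:
        path = os.path.basename(args.resume_from_checkpoint)
    if path is not None:
        accelerator.load_state(os.path.join(args.output_dir, path))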
@@ -910,6 +939,15 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.append(text_encoder)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
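
This second hunk applies the same upcast in the main path, right before the optimizer is built. The ordering matters because an optimizer allocates its state buffers with the dtype of the parameters it is given, so fp16 trainable parameters would also get fp16 optimizer state. A small, self-contained illustration of that dtype behavior; the `zeros_like` call mirrors how AdamW allocates its moment buffers, and the example is not taken from the script:

import torch

# Moment buffers inherit the parameter dtype, e.g. torch.zeros_like(param) inside AdamW.
lora_weight = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.float16))
print(torch.zeros_like(lora_weight).dtype)  # torch.float16

# What cast_training_params effectively does for the trainable params:
lora_weight.data = lora_weight.data.to(torch.float32)
print(torch.zeros_like(lora_weight).dtype)  # torch.float32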
