Commit 79df503

[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) (#6514)
* fix: training resume from fp16.
* add: comment
* remove residue from another branch.
* remove more residues.
* thanks to Younes; no hacks.
* style.
* clean things a bit and modularize _set_state_dict_into_text_encoder
* add comment about the fix detailed.
1 parent 7d63182 commit 79df503
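The gist of the fix: when training resumes from a checkpoint with `--mixed_precision=fp16`, the LoRA weights restored by the load hook come back in the half-precision `weight_dtype` of the base models, so the trainable parameters have to be upcast to float32 again inside the hook; the upcast that used to run right after adding the adapters now runs later, just before the LoRA parameters are collected for the optimizer. A minimal sketch of the pattern, reusing variable names from the script below:

# Only the trainable (LoRA) parameters are upcast; the frozen base weights stay in fp16.
# `unet`, `text_encoder_one`, `text_encoder_two` and `args` are defined in the script below.
import torch

if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(torch.float32)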

File tree

3 files changed (+80, -25 lines)

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 51 additions & 23 deletions
@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

     if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
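For orientation, the hook above only matters on resume: Accelerate invokes `load_model_hook` from inside `accelerator.load_state(...)`, which is what `--resume_from_checkpoint` triggers. A simplified sketch of that flow (the script additionally resolves the latest `checkpoint-*` directory, omitted here):

# Hooks are registered once, up front (as in the diff above).
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

if args.resume_from_checkpoint:
    # Assumption for this sketch: `args.resume_from_checkpoint` already points at a
    # concrete checkpoint directory. Restoring the state runs `load_model_hook`, which
    # reloads the LoRA weights and upcasts the trainable params back to float32.
    accelerator.load_state(args.resume_from_checkpoint)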

src/diffusers/loaders/lora.py

Lines changed: 0 additions & 1 deletion
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
             lora_config_kwargs = get_peft_kwargs(
                 rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
             )
-
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name

src/diffusers/training_utils.py

Lines changed: 29 additions & 1 deletion
@@ -8,12 +8,21 @@
 from torchvision import transforms

 from .models import UNet2DConditionModel
-from .utils import deprecate, is_transformers_available
+from .utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    deprecate,
+    is_peft_available,
+    is_transformers_available,
+)


 if is_transformers_available():
     import transformers

+if is_peft_available():
+    from peft import set_peft_model_state_dict
+

 def set_seed(seed: int):
     """
@@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict


+def _set_state_dict_into_text_encoder(
+    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
+):
+    """
+    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
+
+    Args:
+        lora_state_dict: The state dictionary to be set.
+        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
+        text_encoder: Where the `lora_state_dict` is to be set.
+    """
+
+    text_encoder_state_dict = {
+        f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+    }
+    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
+    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
