
Commit 181280b

Fixes training resuming: Advanced Dreambooth LoRa Training (#6566)
* Fixes #6418 Advanced Dreambooth LoRa Training
* change order of import to fix nit
* fix nit, use cast_training_params
* remove torch.compile fix, will move to a new PR
* remove unnecessary import
1 parent 53f498d commit 181280b


examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 36 additions & 22 deletions
@@ -38,7 +38,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -58,12 +58,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
@@ -1292,17 +1293,6 @@ def main(args):
             else:
                 param.requires_grad = False
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
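Note on the deletion above: the manual upcast loop is what `cast_training_params` from `diffusers.training_utils` now takes care of in the hunks below. A minimal sketch of the equivalent behavior, assuming only that parameters with `requires_grad=True` are the LoRA weights; the helper name here is illustrative, not the library implementation:

import torch

def upcast_trainable_params(models, dtype=torch.float32):
    # Upcast only the trainable (LoRA) parameters; frozen base weights stay in
    # the mixed-precision dtype (e.g. fp16) to keep memory usage low.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

Centralizing this in one helper lets both the resume path (inside `load_model_hook`) and the regular training setup apply the same casting rule instead of duplicating the loop.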
@@ -1358,17 +1348,34 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
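For context, the new loading path in this hunk can also be exercised outside the hook. The sketch below assumes a UNet that already has a LoRA adapter attached (with the same rank and target modules used at training time) and a `checkpoint_dir` written by `accelerator.save_state`; the model ID, rank, and paths are placeholders:

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict

checkpoint_dir = "output/checkpoint-500"  # placeholder path

# Recreate the UNet and attach an adapter matching the training configuration.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.add_adapter(
    LoraConfig(r=4, lora_alpha=4, init_lora_weights="gaussian",
               target_modules=["to_k", "to_q", "to_v", "to_out.0"])
)

# Read the serialized LoRA weights produced by save_model_hook.
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(checkpoint_dir)

# Keep only the UNet entries, strip the "unet." prefix, and convert the key
# naming from the diffusers serialization format to the PEFT format.
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)

# Load into the attached adapter and surface any keys that did not match
# rather than failing silently.
incompatible = set_peft_model_state_dict(unet, unet_state_dict, adapter_name="default")
if incompatible is not None and getattr(incompatible, "unexpected_keys", None):
    print(f"unexpected keys: {incompatible.unexpected_keys}")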
@@ -1383,6 +1390,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
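The hooks themselves only define how checkpoint state is serialized and restored; saving and resuming still go through Accelerate's checkpointing API. A minimal sketch of how they are exercised, assuming an `accelerator` object and the `save_model_hook`/`load_model_hook` from this diff are in scope (the checkpoint path is a placeholder):

# Register the custom (de)serialization hooks once, before the training loop.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# Periodic checkpointing during training invokes save_model_hook under the hood.
accelerator.save_state("output/checkpoint-500")  # placeholder path

# On resume (e.g. --resume_from_checkpoint), load_model_hook restores the LoRA
# weights into the live models and re-upcasts the trainable params to fp32,
# which is the behavior this commit fixes.
accelerator.load_state("output/checkpoint-500")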
