Commit b238c1e

Fixes #6418 Advanced Dreambooth LoRa Training
1 parent 79df503 commit b238c1e

File tree

1 file changed: +56 -27

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 56 additions & 27 deletions
@@ -38,7 +38,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -58,15 +58,17 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
     convert_state_dict_to_kohya,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1292,16 +1294,10 @@ def main(args):
             else:
                 param.requires_grad = False

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
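The `unwrap_model` helper introduced in this hunk strips both the accelerate wrapper and the `torch.compile` wrapper before the type checks in the save/load hooks below (the fp16 upcast it replaces reappears later in the diff). A minimal standalone sketch of the wrapper behaviour it handles, assuming PyTorch >= 2.0 and a toy `Linear` module that is not part of the script:

```python
import torch

# torch.compile wraps the module in an OptimizedModule and keeps the original
# module under `_orig_mod`, so isinstance checks against the original class
# only succeed after unwrapping.
linear = torch.nn.Linear(4, 4)
compiled = torch.compile(linear)

print(type(compiled).__name__)                 # OptimizedModule
print(isinstance(compiled, torch.nn.Linear))   # False: the wrapper hides the class
print(compiled._orig_mod is linear)            # True: the original module is kept

unwrapped = compiled._orig_mod if hasattr(compiled, "_orig_mod") else compiled
print(isinstance(unwrapped, torch.nn.Linear))  # True once unwrapped
```

`accelerator.unwrap_model` removes the distributed-training wrapper in the same spirit, which is why the helper applies both steps.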
@@ -1313,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
         text_encoder_two_lora_layers_to_save = None

         for model in models:
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 if args.train_text_encoder:
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 if args.train_text_encoder:
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
@@ -1348,27 +1344,49 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
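The rewritten `load_model_hook` no longer goes through `LoraLoaderMixin.load_lora_into_unet` / `load_lora_into_text_encoder`; it splits the combined LoRA state dict by key prefix and hands the pieces to `set_peft_model_state_dict` and `_set_state_dict_into_text_encoder`. A toy sketch of that prefix routing, using made-up keys and dummy values rather than a real checkpoint:

```python
# Toy LoRA state dict with the three prefixes the SDXL scripts save under.
lora_state_dict = {
    "unet.down_blocks.0.attentions.0.to_q.lora.down.weight": 0,
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora.down.weight": 1,
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora.down.weight": 2,
}

# UNet keys: drop the "unet." prefix; the script then converts them to PEFT
# naming with convert_unet_state_dict_to_peft before set_peft_model_state_dict.
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}

# Text-encoder keys keep their prefixes; _set_state_dict_into_text_encoder
# selects them via its `prefix=` argument. Note that startswith("text_encoder.")
# does not match the "text_encoder_2." keys, so the two encoders stay separate.
text_encoder_one_keys = [k for k in lora_state_dict if k.startswith("text_encoder.")]
text_encoder_two_keys = [k for k in lora_state_dict if k.startswith("text_encoder_2.")]

print(sorted(unet_state_dict))  # ['down_blocks.0.attentions.0.to_q.lora.down.weight']
print(text_encoder_one_keys)    # only the text_encoder.* key
print(text_encoder_two_keys)    # only the text_encoder_2.* key
```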
@@ -1383,6 +1401,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

     if args.train_text_encoder:
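This hunk re-adds the selective fp32 upcast at main() level (the same pattern also appears inside `load_model_hook` above), so only the trainable LoRA parameters leave fp16 while the frozen base weights keep their `weight_dtype`. A minimal illustration with a toy `Linear`, where `requires_grad` stands in for the LoRA/base split; this is not the training script itself:

```python
import torch

# Toy module in fp16: the frozen "base" weight stays in half precision,
# only the trainable parameter (standing in for a LoRA weight) is upcast.
layer = torch.nn.Linear(8, 8).to(torch.float16)
layer.weight.requires_grad_(False)  # frozen base weight
layer.bias.requires_grad_(True)     # stand-in for a trainable LoRA parameter

for param in layer.parameters():
    # only upcast trainable parameters into fp32, as in the hunk above
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

print(layer.weight.dtype)  # torch.float16
print(layer.bias.dtype)    # torch.float32
```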
