From 6ccc66cc810c0854c58acc08aa8447f6c95b77b2 Mon Sep 17 00:00:00 2001
From: Steve Rhoades
Date: Sat, 13 Jan 2024 09:50:17 -0800
Subject: [PATCH 1/5] Fixes #6418 Advanced Dreambooth LoRA Training

---
 .../train_dreambooth_lora_sdxl_advanced.py   | 83 +++++++++++++------
 1 file changed, 56 insertions(+), 27 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index ddd8114ae4f2..1d34fc91df17 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -38,7 +38,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -58,15 +58,17 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
     convert_state_dict_to_kohya,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1292,16 +1294,10 @@ def main(args):
         else:
             param.requires_grad = False

-    # Make sure the trainable params are in float32.
- if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1313,14 +1309,14 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): if args.train_text_encoder: text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): if args.train_text_encoder: text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -1348,27 +1344,49 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ - ) + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) - text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ - ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? 
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1383,6 +1401,17 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: From 938a65302f0579b253419338882ec5d0a8dfdbb8 Mon Sep 17 00:00:00 2001 From: Steve Rhoades Date: Sat, 13 Jan 2024 10:08:49 -0800 Subject: [PATCH 2/5] change order of import to fix nit --- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 1d34fc91df17..ddd531452e9e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -63,8 +63,8 @@ check_min_version, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, - convert_unet_state_dict_to_peft, convert_state_dict_to_kohya, + convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.import_utils import is_xformers_available From ca1f7df965c77a9e4df20b58d28656c90eaa1023 Mon Sep 17 00:00:00 2001 From: Steve Rhoades Date: Mon, 15 Jan 2024 20:15:11 -0800 Subject: [PATCH 3/5] fix nit, use cast_training_params --- .../train_dreambooth_lora_sdxl_advanced.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ddd531452e9e..568279d9be3e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -58,7 +58,7 @@ ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, @@ -1368,7 +1368,6 @@ def 
load_model_hook(models, input_dir): ) if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) _set_state_dict_into_text_encoder( @@ -1382,11 +1381,7 @@ def load_model_hook(models, input_dir): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1406,11 +1401,7 @@ def load_model_hook(models, input_dir): models = [unet] if args.train_text_encoder: models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + cast_training_params(models, dtype=torch.float32) unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) From 2a0176cd8ece1d1a369799b3dc1bb00589fe4b4c Mon Sep 17 00:00:00 2001 From: Steve Rhoades Date: Mon, 15 Jan 2024 20:32:19 -0800 Subject: [PATCH 4/5] remove torch.compile fix, will move to a new PR --- .../train_dreambooth_lora_sdxl_advanced.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 568279d9be3e..a88f95e925af 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1294,11 +1294,6 @@ def main(args): else: param.requires_grad = False - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1309,14 +1304,14 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(unet))): + if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): if args.train_text_encoder: text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(unwrap_model(text_encoder_two))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): if args.train_text_encoder: text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -1344,11 +1339,11 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(unwrap_model(unet))): + if isinstance(model, type(accelerator.unwrap_model(unet))): unet_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(model, 
type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(unwrap_model(text_encoder_two))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") From e0fb77c612694c2e37f4f420edb0e72889ffcf12 Mon Sep 17 00:00:00 2001 From: Steve Rhoades Date: Mon, 15 Jan 2024 20:37:04 -0800 Subject: [PATCH 5/5] remove unnecessary import --- .../train_dreambooth_lora_sdxl_advanced.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a88f95e925af..3db9ff65e441 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -68,7 +68,6 @@ is_wandb_available, ) from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
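Note for context: the series fixes #6418, resuming advanced DreamBooth LoRA training from a checkpoint, and the hooks patched above are what runs on that save/load round trip. A minimal, hedged sketch of the flow follows; it assumes the script's scope (accelerator and the two hook functions defined above), and the checkpoint path and step count are illustrative, not taken from the script.

    # Illustrative round trip. save_model_hook runs inside save_state,
    # load_model_hook inside load_state; after this series, load_model_hook
    # reloads the unet LoRA weights in PEFT format via set_peft_model_state_dict.
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    accelerator.save_state("output/checkpoint-500")  # serializes the LoRA layers
    accelerator.load_state("output/checkpoint-500")  # restores them for resumed training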
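PATCH 3 swaps the inline fp32 upcast loops for diffusers.training_utils.cast_training_params. As a reference, here is a minimal sketch of the pattern that helper replaces, under the same assumption as the removed loop: only parameters with requires_grad set (the LoRA adapters) are upcast, while frozen base weights stay in weight_dtype. The function name below is a hypothetical stand-in, not the library implementation.

    import torch

    def cast_training_params_sketch(models, dtype=torch.float32):
        # Hypothetical stand-in mirroring the inline loop removed in PATCH 3.
        if not isinstance(models, list):
            models = [models]
        for model in models:
            for param in model.parameters():
                # only upcast trainable parameters (LoRA) into fp32
                if param.requires_grad:
                    param.data = param.to(dtype)

Re-running this inside load_model_hook matters because, as the patch's own comment notes, the models are back in weight_dtype at that point, and training under --mixed_precision="fp16" expects the trainable params in float32.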