 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

     if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()