@@ -35,7 +35,7 @@
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -54,6 +54,12 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -892,10 +898,33 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
-        )
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.append(text_encoder_)
+
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models, dtype=torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -910,6 +939,15 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.append(text_encoder)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
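For reference, here is a minimal, hedged sketch of the resume path these hunks introduce, mirroring the new load_model_hook outside the training script. The model id, the "checkpoint-500" directory, and the LoRA hyperparameters are illustrative placeholders, and it assumes diffusers/peft versions that provide convert_unet_state_dict_to_peft, set_peft_model_state_dict, and cast_training_params.

# Sketch only, not part of the commit. Placeholder model id and checkpoint directory.
import torch
from peft import LoraConfig
from peft.utils import set_peft_model_state_dict

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft

base_model = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # placeholder base model
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet", torch_dtype=torch.float16)
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# Mirror load_model_hook: read the serialized LoRA weights, keep the UNet entries,
# convert them to the peft naming scheme, and load them as the "default" adapter.
lora_state_dict, _ = LoraLoaderMixin.lora_state_dict("checkpoint-500")  # placeholder checkpoint dir
unet_sd = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
set_peft_model_state_dict(unet, convert_unet_state_dict_to_peft(unet_sd), adapter_name="default")

# With fp16 mixed precision, only the trainable LoRA parameters are upcast to fp32;
# the frozen base weights stay in half precision.
cast_training_params([unet], dtype=torch.float32)
assert all(p.dtype == torch.float32 for p in unet.parameters() if p.requires_grad)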