@@ -56,6 +56,7 @@
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1007,6 +1008,11 @@ def main(args):
                if param.requires_grad:
                    param.data = param.to(torch.float32)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
@@ -1017,13 +1023,13 @@ def save_model_hook(models, weights, output_dir):
            text_encoder_two_lora_layers_to_save = None

            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(model)
                    )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(model)
                    )
@@ -1048,11 +1054,11 @@ def load_model_hook(models, input_dir):
        while len(models) > 0:
            model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                text_encoder_two_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
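
Note (not part of the patch): in the diffusers training scripts, `save_model_hook` and `load_model_hook` are typically registered on the accelerator so that `accelerator.save_state(...)` and `accelerator.load_state(...)` route through the LoRA-aware serialization above. A minimal sketch of that wiring using the Accelerate hook API; the exact placement in this script is assumed, not shown in this diff:

    # Sketch: attach the custom hooks to Accelerate's checkpointing (assumed placement)
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
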
@@ -1621,16 +1627,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
        unet = unet.to(torch.float32)
        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

        if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = convert_state_dict_to_diffusers(
                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            )
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_two = unwrap_model(text_encoder_two)
            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
            )
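
Why the new helper peels off `_orig_mod`: `torch.compile` (PyTorch 2.x) wraps a module in an `OptimizedModule`, so the `isinstance(model, type(...))` checks above would stop matching a compiled UNet or text encoder unless that wrapper is removed first. A minimal, self-contained sketch of the behavior the helper relies on, using a toy `nn.Linear` as a stand-in for the real model (illustration only, not from the patch):

    import torch
    import torch.nn as nn

    unet = nn.Linear(4, 4)           # stand-in for the real UNet
    compiled = torch.compile(unet)   # torch.compile returns an OptimizedModule wrapper

    print(isinstance(compiled, type(unet)))            # False: the wrapper hides the original class
    print(isinstance(compiled._orig_mod, type(unet)))  # True: unwrapping restores the original module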