54
54
from diffusers .training_utils import cast_training_params , compute_snr
55
55
from diffusers .utils import check_min_version , convert_state_dict_to_diffusers , is_wandb_available
56
56
from diffusers .utils .import_utils import is_xformers_available
57
+ from diffusers .utils .torch_utils import is_compiled_module
57
58
58
59
59
60
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
460
461
text_input_ids = text_input_ids_list [i ]
461
462
462
463
prompt_embeds = text_encoder (
463
- text_input_ids .to (text_encoder .device ),
464
- output_hidden_states = True ,
464
+ text_input_ids .to (text_encoder .device ), output_hidden_states = True , return_dict = False
465
465
)
466
466
467
467
# We are only ALWAYS interested in the pooled output of the final text encoder
468
468
pooled_prompt_embeds = prompt_embeds [0 ]
469
- prompt_embeds = prompt_embeds . hidden_states [- 2 ]
469
+ prompt_embeds = prompt_embeds [ - 1 ] [- 2 ]
470
470
bs_embed , seq_len , _ = prompt_embeds .shape
471
471
prompt_embeds = prompt_embeds .view (bs_embed , seq_len , - 1 )
472
472
prompt_embeds_list .append (prompt_embeds )
@@ -637,6 +637,11 @@ def main(args):
637
637
# only upcast trainable parameters (LoRA) into fp32
638
638
cast_training_params (models , dtype = torch .float32 )
639
639
640
+ def unwrap_model (model ):
641
+ model = accelerator .unwrap_model (model )
642
+ model = model ._orig_mod if is_compiled_module (model ) else model
643
+ return model
644
+
640
645
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
641
646
def save_model_hook (models , weights , output_dir ):
642
647
if accelerator .is_main_process :
@@ -647,13 +652,13 @@ def save_model_hook(models, weights, output_dir):
647
652
text_encoder_two_lora_layers_to_save = None
648
653
649
654
for model in models :
650
- if isinstance (model , type (accelerator . unwrap_model (unet ))):
655
+ if isinstance (model , type (unwrap_model (unet ))):
651
656
unet_lora_layers_to_save = convert_state_dict_to_diffusers (get_peft_model_state_dict (model ))
652
- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
657
+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
653
658
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers (
654
659
get_peft_model_state_dict (model )
655
660
)
656
- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
661
+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
657
662
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers (
658
663
get_peft_model_state_dict (model )
659
664
)
@@ -678,11 +683,11 @@ def load_model_hook(models, input_dir):
678
683
while len (models ) > 0 :
679
684
model = models .pop ()
680
685
681
- if isinstance (model , type (accelerator . unwrap_model (unet ))):
686
+ if isinstance (model , type (unwrap_model (unet ))):
682
687
unet_ = model
683
- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
688
+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
684
689
text_encoder_one_ = model
685
- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
690
+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
686
691
text_encoder_two_ = model
687
692
else :
688
693
raise ValueError (f"unexpected save model: { model .__class__ } " )
@@ -1031,8 +1036,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
1031
1036
)
1032
1037
unet_added_conditions .update ({"text_embeds" : pooled_prompt_embeds })
1033
1038
model_pred = unet (
1034
- noisy_model_input , timesteps , prompt_embeds , added_cond_kwargs = unet_added_conditions
1035
- ).sample
1039
+ noisy_model_input ,
1040
+ timesteps ,
1041
+ prompt_embeds ,
1042
+ added_cond_kwargs = unet_added_conditions ,
1043
+ return_dict = False ,
1044
+ )[0 ]
1036
1045
1037
1046
# Get the target for loss depending on the prediction type
1038
1047
if args .prediction_type is not None :
@@ -1125,9 +1134,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
1125
1134
pipeline = StableDiffusionXLPipeline .from_pretrained (
1126
1135
args .pretrained_model_name_or_path ,
1127
1136
vae = vae ,
1128
- text_encoder = accelerator . unwrap_model (text_encoder_one ),
1129
- text_encoder_2 = accelerator . unwrap_model (text_encoder_two ),
1130
- unet = accelerator . unwrap_model (unet ),
1137
+ text_encoder = unwrap_model (text_encoder_one ),
1138
+ text_encoder_2 = unwrap_model (text_encoder_two ),
1139
+ unet = unwrap_model (unet ),
1131
1140
revision = args .revision ,
1132
1141
variant = args .variant ,
1133
1142
torch_dtype = weight_dtype ,
@@ -1166,12 +1175,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
1166
1175
# Save the lora layers
1167
1176
accelerator .wait_for_everyone ()
1168
1177
if accelerator .is_main_process :
1169
- unet = accelerator . unwrap_model (unet )
1178
+ unet = unwrap_model (unet )
1170
1179
unet_lora_state_dict = convert_state_dict_to_diffusers (get_peft_model_state_dict (unet ))
1171
1180
1172
1181
if args .train_text_encoder :
1173
- text_encoder_one = accelerator . unwrap_model (text_encoder_one )
1174
- text_encoder_two = accelerator . unwrap_model (text_encoder_two )
1182
+ text_encoder_one = unwrap_model (text_encoder_one )
1183
+ text_encoder_two = unwrap_model (text_encoder_two )
1175
1184
1176
1185
text_encoder_lora_layers = convert_state_dict_to_diffusers (get_peft_model_state_dict (text_encoder_one ))
1177
1186
text_encoder_2_lora_layers = convert_state_dict_to_diffusers (get_peft_model_state_dict (text_encoder_two ))
0 commit comments