from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
+       return_dict=False,
    )
    prompt_embeds = prompt_embeds[0]

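With `return_dict=False`, the text encoder returns a plain tuple instead of a `transformers` `ModelOutput`, so `prompt_embeds[0]` still selects the last hidden state while skipping the output-dataclass construction, presumably to avoid graph breaks under `torch.compile`. A minimal standalone sketch of the two call conventions, using a tiny randomly initialized CLIP text encoder (illustration only, not the trained model):

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random model purely for illustration
config = CLIPTextConfig(
    vocab_size=1000, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=16,
)
text_encoder = CLIPTextModel(config).eval()
ids = torch.randint(0, 1000, (1, 8))

with torch.no_grad():
    as_output = text_encoder(ids)                    # BaseModelOutputWithPooling
    as_tuple = text_encoder(ids, return_dict=False)  # plain tuple

# Index 0 is the last hidden state in both cases
assert torch.equal(as_output.last_hidden_state, as_tuple[0])
```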
@@ -843,6 +845,11 @@ def main(args):
        )
        text_encoder.add_adapter(text_lora_config)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
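`torch.compile` wraps a module in an `OptimizedModule` whose type no longer matches the original class, so plain `accelerator.unwrap_model` is not enough for the `isinstance(model, type(...))` checks used by the hooks below; the helper additionally reaches through `_orig_mod` when `is_compiled_module` detects the wrapper. A standalone illustration (requires PyTorch 2.x):

```python
import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)                    # OptimizedModule
assert not isinstance(compiled, torch.nn.Linear)  # type check fails on the wrapper
assert compiled._orig_mod is model                # the original module is kept on _orig_mod
assert isinstance(compiled._orig_mod, torch.nn.Linear)
```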
@@ -852,9 +859,9 @@ def save_model_hook(models, weights, output_dir):
            text_encoder_lora_layers_to_save = None

            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                elif isinstance(model, type(unwrap_model(text_encoder))):
                    text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(model)
                    )
@@ -877,9 +884,9 @@ def load_model_hook(models, input_dir):
        while len(models) > 0:
            model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            elif isinstance(model, type(unwrap_model(text_encoder))):
                text_encoder_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
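Both hooks receive the `accelerate`-prepared (and possibly compiled) modules, which is why they dispatch on the fully unwrapped type. They take effect once registered on the accelerator, which the script does in unchanged code that does not appear in this diff; a sketch of that registration pattern:

```python
# Sketch; `accelerator`, `save_model_hook`, and `load_model_hook` as defined in the script.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# After registration, checkpointing goes through the custom hooks, e.g.:
# accelerator.save_state(os.path.join(args.output_dir, "checkpoint-500"))
```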
@@ -1118,7 +1125,7 @@ def compute_text_embeddings(prompt):
                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                    )

-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                if args.class_labels_conditioning == "timesteps":
@@ -1128,8 +1135,12 @@ def compute_text_embeddings(prompt):

                # Predict the noise residual
                model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    encoder_hidden_states,
+                    class_labels=class_labels,
+                    return_dict=False,
+                )[0]

                # if model predicts variance, throw away the prediction. we will only train on the
                # simplified training objective. This means that all schedulers using the fine tuned
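The UNet call gets the same treatment: with `return_dict=False` the forward pass returns a plain tuple with the sample at index 0 instead of a `UNet2DConditionOutput`, so `[0]` replaces `.sample`. A sketch of the equivalence, with the variables as they appear in the training loop above:

```python
# Sketch; noisy_model_input, timesteps, and encoder_hidden_states as in the loop.
pred_dataclass = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
pred_tuple = unet(noisy_model_input, timesteps, encoder_hidden_states, return_dict=False)[0]
# Numerically identical; the tuple form skips building the output
# dataclass, which avoids a potential graph break under torch.compile.
```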
@@ -1215,8 +1226,8 @@ def compute_text_embeddings(prompt):
                # create pipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
+                    unet=unwrap_model(unet),
+                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
@@ -1284,13 +1295,13 @@ def compute_text_embeddings(prompt):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
        unet = unet.to(torch.float32)

        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

        if args.train_text_encoder:
-            text_encoder = accelerator.unwrap_model(text_encoder)
+            text_encoder = unwrap_model(text_encoder)
            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
        else:
            text_encoder_state_dict = None
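After this hunk the script serializes the converted state dicts in unchanged code, so the write itself does not appear in the diff. For context, a hedged sketch of that save step as it exists in this diffusers version (a hypothetical reconstruction, not part of this PR):

```python
# Hypothetical reconstruction of the save call that follows; not part of this diff.
from diffusers.loaders import LoraLoaderMixin

LoraLoaderMixin.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_state_dict,
    text_encoder_lora_layers=text_encoder_state_dict,
)
```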