@@ -68,6 +68,7 @@
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1293,6 +1294,11 @@ def main(args):
         else:
             param.requires_grad = False
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1303,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     if args.train_text_encoder:
                         text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
                         )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     if args.train_text_encoder:
                         text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(model)
@@ -1338,11 +1344,11 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
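
The `unwrap_model` helper added above guards against models wrapped by `torch.compile`: `accelerator.unwrap_model` only strips the Accelerate/DDP wrapper, while a compiled model is an `OptimizedModule` that keeps the original module on `._orig_mod`, so the `type(...)`-based dispatch in the save/load hooks would otherwise miss. A minimal sketch of that behaviour, using a plain `nn.Linear` as a stand-in for the real UNet and text encoders (illustrative only, not part of the diff):

```python
import torch
from diffusers.utils.torch_utils import is_compiled_module

base = torch.nn.Linear(4, 4)      # stand-in for the real UNet / text encoders
compiled = torch.compile(base)    # wraps `base` in torch._dynamo's OptimizedModule

print(type(compiled) is torch.nn.Linear)   # False: the compile wrapper hides the class
print(is_compiled_module(compiled))        # True
print(compiled._orig_mod is base)          # True: the original module is recoverable

def unwrap_model(model):
    # Same shape as the helper in the diff, minus the accelerator.unwrap_model step,
    # which only applies once the model has gone through accelerator.prepare(...).
    return model._orig_mod if is_compiled_module(model) else model

assert type(unwrap_model(compiled)) is torch.nn.Linear
```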