@@ -50,6 +50,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 MAX_SEQ_LENGTH = 77
@@ -926,6 +927,11 @@ def load_model_hook(models, input_dir):
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

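Note on the new helper: `accelerator.unwrap_model` strips the Accelerate/DDP wrapper, but a `torch.compile`d module is additionally wrapped in `torch._dynamo.OptimizedModule`, which keeps the original module in `_orig_mod`. A minimal sketch of that behavior (assuming PyTorch >= 2.0; `TinyModule` is a hypothetical stand-in, not part of this script):

```python
import torch
from diffusers.utils.torch_utils import is_compiled_module

# Hypothetical stand-in module, only for illustration.
class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

model = TinyModule()
compiled = torch.compile(model)  # wraps model in torch._dynamo.OptimizedModule

print(is_compiled_module(model))     # False
print(is_compiled_module(compiled))  # True
print(compiled._orig_mod is model)   # True: the original module is kept in _orig_mod
```

The later hunks rely on this double unwrap: attribute lookups such as `.dtype` and methods such as `save_pretrained` live on the underlying diffusers model, not on the compile wrapper.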
@@ -935,9 +941,9 @@ def load_model_hook(models, input_dir):
         " doing mixed precision training, copy of the weights should still be float32."
     )

-    if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
+    if unwrap_model(t2iadapter).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
         )

     # Enable TF32 for faster training on Ampere GPUs,
@@ -1198,7 +1204,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     encoder_hidden_states=batch["prompt_ids"],
                     added_cond_kwargs=batch["unet_added_conditions"],
                     down_block_additional_residuals=down_block_additional_residuals,
-                ).sample
+                    return_dict=False,
+                )[0]

                 # Denoise the latents
                 denoised_latents = model_pred * (-sigmas) + noisy_latents
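On the `.sample` to `return_dict=False` change: by default `UNet2DConditionModel.forward` returns a `UNet2DConditionOutput` dataclass, while `return_dict=False` makes it return a plain tuple whose first element is the same prediction tensor. A minimal sketch with a hypothetical tiny UNet config (not the SDXL UNet this script trains with):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny randomly initialized UNet, purely illustrative.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
latents = torch.randn(1, 4, 8, 8)
timesteps = torch.tensor([10])
text_embeds = torch.randn(1, 77, 32)

with torch.no_grad():
    # Default: the forward returns a UNet2DConditionOutput dataclass.
    pred_a = unet(latents, timesteps, encoder_hidden_states=text_embeds).sample
    # return_dict=False: the forward returns a plain tuple instead.
    pred_b = unet(latents, timesteps, encoder_hidden_states=text_embeds, return_dict=False)[0]

print(torch.allclose(pred_a, pred_b))  # True: same tensor either way
```

The tuple return avoids the output-dataclass wrapper, which presumably is why this PR switches to it alongside the compile-aware unwrap helper above.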
@@ -1279,7 +1286,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        t2iadapter = accelerator.unwrap_model(t2iadapter)
+        t2iadapter = unwrap_model(t2iadapter)
         t2iadapter.save_pretrained(args.output_dir)

     if args.push_to_hub:
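Once `save_pretrained` has written the adapter, it can be reloaded for inference. A hedged sketch (the base-model id and output path are placeholders, not taken from this diff):

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

# Placeholders: substitute the script's actual --output_dir.
adapter = T2IAdapter.from_pretrained("path/to/output_dir", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base checkpoint
    adapter=adapter,
    torch_dtype=torch.float16,
)
```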