55
55
from diffusers .training_utils import compute_snr
56
56
from diffusers .utils import check_min_version , is_wandb_available
57
57
from diffusers .utils .import_utils import is_xformers_available
58
+ from diffusers .utils .torch_utils import is_compiled_module
58
59
59
60
60
61
if is_wandb_available ():
@@ -129,15 +130,12 @@ def log_validation(
129
130
if vae is not None :
130
131
pipeline_args ["vae" ] = vae
131
132
132
- if text_encoder is not None :
133
- text_encoder = accelerator .unwrap_model (text_encoder )
134
-
135
133
# create pipeline (note: unet and vae are loaded again in float32)
136
134
pipeline = DiffusionPipeline .from_pretrained (
137
135
args .pretrained_model_name_or_path ,
138
136
tokenizer = tokenizer ,
139
137
text_encoder = text_encoder ,
140
- unet = accelerator . unwrap_model ( unet ) ,
138
+ unet = unet ,
141
139
revision = args .revision ,
142
140
variant = args .variant ,
143
141
torch_dtype = weight_dtype ,
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
794
792
prompt_embeds = text_encoder (
795
793
text_input_ids ,
796
794
attention_mask = attention_mask ,
795
+ return_dict = False ,
797
796
)
798
797
prompt_embeds = prompt_embeds [0 ]
799
798
@@ -931,11 +930,16 @@ def main(args):
931
930
args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args .variant
932
931
)
933
932
933
+ def unwrap_model (model ):
934
+ model = accelerator .unwrap_model (model )
935
+ model = model ._orig_mod if is_compiled_module (model ) else model
936
+ return model
937
+
934
938
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
935
939
def save_model_hook (models , weights , output_dir ):
936
940
if accelerator .is_main_process :
937
941
for model in models :
938
- sub_dir = "unet" if isinstance (model , type (accelerator . unwrap_model (unet ))) else "text_encoder"
942
+ sub_dir = "unet" if isinstance (model , type (unwrap_model (unet ))) else "text_encoder"
939
943
model .save_pretrained (os .path .join (output_dir , sub_dir ))
940
944
941
945
# make sure to pop weight so that corresponding model is not saved again
@@ -946,7 +950,7 @@ def load_model_hook(models, input_dir):
946
950
# pop models so that they are not loaded again
947
951
model = models .pop ()
948
952
949
- if isinstance (model , type (accelerator . unwrap_model (text_encoder ))):
953
+ if isinstance (model , type (unwrap_model (text_encoder ))):
950
954
# load transformers style into model
951
955
load_model = text_encoder_cls .from_pretrained (input_dir , subfolder = "text_encoder" )
952
956
model .config = load_model .config
@@ -991,15 +995,12 @@ def load_model_hook(models, input_dir):
991
995
" doing mixed precision training. copy of the weights should still be float32."
992
996
)
993
997
994
- if accelerator .unwrap_model (unet ).dtype != torch .float32 :
995
- raise ValueError (
996
- f"Unet loaded as datatype { accelerator .unwrap_model (unet ).dtype } . { low_precision_error_string } "
997
- )
998
+ if unwrap_model (unet ).dtype != torch .float32 :
999
+ raise ValueError (f"Unet loaded as datatype { unwrap_model (unet ).dtype } . { low_precision_error_string } " )
998
1000
999
- if args .train_text_encoder and accelerator . unwrap_model (text_encoder ).dtype != torch .float32 :
1001
+ if args .train_text_encoder and unwrap_model (text_encoder ).dtype != torch .float32 :
1000
1002
raise ValueError (
1001
- f"Text encoder loaded as datatype { accelerator .unwrap_model (text_encoder ).dtype } ."
1002
- f" { low_precision_error_string } "
1003
+ f"Text encoder loaded as datatype { unwrap_model (text_encoder ).dtype } ." f" { low_precision_error_string } "
1003
1004
)
1004
1005
1005
1006
# Enable TF32 for faster training on Ampere GPUs,
@@ -1246,7 +1247,7 @@ def compute_text_embeddings(prompt):
1246
1247
text_encoder_use_attention_mask = args .text_encoder_use_attention_mask ,
1247
1248
)
1248
1249
1249
- if accelerator . unwrap_model (unet ).config .in_channels == channels * 2 :
1250
+ if unwrap_model (unet ).config .in_channels == channels * 2 :
1250
1251
noisy_model_input = torch .cat ([noisy_model_input , noisy_model_input ], dim = 1 )
1251
1252
1252
1253
if args .class_labels_conditioning == "timesteps" :
@@ -1256,8 +1257,8 @@ def compute_text_embeddings(prompt):
1256
1257
1257
1258
# Predict the noise residual
1258
1259
model_pred = unet (
1259
- noisy_model_input , timesteps , encoder_hidden_states , class_labels = class_labels
1260
- ). sample
1260
+ noisy_model_input , timesteps , encoder_hidden_states , class_labels = class_labels , return_dict = False
1261
+ )[ 0 ]
1261
1262
1262
1263
if model_pred .shape [1 ] == 6 :
1263
1264
model_pred , _ = torch .chunk (model_pred , 2 , dim = 1 )
@@ -1350,9 +1351,9 @@ def compute_text_embeddings(prompt):
1350
1351
1351
1352
if args .validation_prompt is not None and global_step % args .validation_steps == 0 :
1352
1353
images = log_validation (
1353
- text_encoder ,
1354
+ unwrap_model ( text_encoder ) if text_encoder is not None else text_encoder ,
1354
1355
tokenizer ,
1355
- unet ,
1356
+ unwrap_model ( unet ) ,
1356
1357
vae ,
1357
1358
args ,
1358
1359
accelerator ,
@@ -1375,14 +1376,14 @@ def compute_text_embeddings(prompt):
1375
1376
pipeline_args = {}
1376
1377
1377
1378
if text_encoder is not None :
1378
- pipeline_args ["text_encoder" ] = accelerator . unwrap_model (text_encoder )
1379
+ pipeline_args ["text_encoder" ] = unwrap_model (text_encoder )
1379
1380
1380
1381
if args .skip_save_text_encoder :
1381
1382
pipeline_args ["text_encoder" ] = None
1382
1383
1383
1384
pipeline = DiffusionPipeline .from_pretrained (
1384
1385
args .pretrained_model_name_or_path ,
1385
- unet = accelerator . unwrap_model (unet ),
1386
+ unet = unwrap_model (unet ),
1386
1387
revision = args .revision ,
1387
1388
variant = args .variant ,
1388
1389
** pipeline_args ,
0 commit comments