from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
@@ -833,6 +834,12 @@ def collate_fn(examples):
        tracker_config.pop("validation_prompts")
        accelerator.init_trackers(args.tracker_project_name, tracker_config)

+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -912,7 +919,7 @@ def collate_fn(examples):
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                # Get the target for loss depending on the prediction type
                if args.prediction_type is not None:
@@ -927,7 +934,7 @@ def collate_fn(examples):
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -1023,7 +1030,7 @@ def collate_fn(examples):
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)

        if args.use_ema:
            ema_unet.copy_to(unet.parameters())
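
For context on the new `unwrap_model` helper: when a module is wrapped with `torch.compile`, PyTorch returns an `OptimizedModule` that keeps the original module on its `_orig_mod` attribute, and `accelerator.unwrap_model` alone does not peel that layer off, so saving or copying weights should go through the original module. A minimal sketch of the behavior, using a hypothetical `TinyNet` stand-in for the UNet (requires PyTorch 2.0+):

```python
import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the trained UNet."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


compiled = torch.compile(TinyNet())

# torch.compile wraps the module in an OptimizedModule and keeps the
# original module on `_orig_mod`; that original is what should be saved
# or have weights copied into it, which is what the `unwrap_model`
# helper added above returns. (diffusers' is_compiled_module does an
# isinstance check on OptimizedModule; hasattr is a rough equivalent.)
original = compiled._orig_mod if hasattr(compiled, "_orig_mod") else compiled
assert isinstance(original, TinyNet)
```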
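The `return_dict=False` changes follow the same motivation: diffusers models return output dataclasses by default (e.g. `UNet2DConditionOutput` with a `.sample` field), and passing `return_dict=False` makes them return plain tuples instead, which trace more cleanly under `torch.compile`. A sketch of the equivalence of the two call styles, using a tiny randomly initialized UNet whose config is hypothetical and chosen purely for illustration:

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny hypothetical UNet config with random weights, for illustration only.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
).eval()

latents = torch.randn(1, 4, 8, 8)
timesteps = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    # Default: an output dataclass, prediction on the `.sample` attribute.
    pred_a = unet(latents, timesteps, encoder_hidden_states).sample
    # With return_dict=False: a plain tuple, prediction at index [0].
    pred_b = unet(latents, timesteps, encoder_hidden_states, return_dict=False)[0]

assert torch.allclose(pred_a, pred_b)
```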