46
46
from diffusers .training_utils import cast_training_params , compute_snr
47
47
from diffusers .utils import check_min_version , convert_state_dict_to_diffusers , is_wandb_available
48
48
from diffusers .utils .import_utils import is_xformers_available
49
+ from diffusers .utils .torch_utils import is_compiled_module
49
50
50
51
51
52
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -596,6 +597,11 @@ def tokenize_captions(examples, is_train=True):
596
597
]
597
598
)
598
599
600
+ def unwrap_model (model ):
601
+ model = accelerator .unwrap_model (model )
602
+ model = model ._orig_mod if is_compiled_module (model ) else model
603
+ return model
604
+
599
605
def preprocess_train (examples ):
600
606
images = [image .convert ("RGB" ) for image in examples [image_column ]]
601
607
examples ["pixel_values" ] = [train_transforms (image ) for image in images ]
@@ -729,7 +735,7 @@ def collate_fn(examples):
729
735
noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
730
736
731
737
# Get the text embedding for conditioning
732
- encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
738
+ encoder_hidden_states = text_encoder (batch ["input_ids" ], return_dict = False )[0 ]
733
739
734
740
# Get the target for loss depending on the prediction type
735
741
if args .prediction_type is not None :
@@ -744,7 +750,7 @@ def collate_fn(examples):
744
750
raise ValueError (f"Unknown prediction type { noise_scheduler .config .prediction_type } " )
745
751
746
752
# Predict the noise residual and compute loss
747
- model_pred = unet (noisy_latents , timesteps , encoder_hidden_states ). sample
753
+ model_pred = unet (noisy_latents , timesteps , encoder_hidden_states , return_dict = False )[ 0 ]
748
754
749
755
if args .snr_gamma is None :
750
756
loss = F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
@@ -809,7 +815,7 @@ def collate_fn(examples):
809
815
save_path = os .path .join (args .output_dir , f"checkpoint-{ global_step } " )
810
816
accelerator .save_state (save_path )
811
817
812
- unwrapped_unet = accelerator . unwrap_model (unet )
818
+ unwrapped_unet = unwrap_model (unet )
813
819
unet_lora_state_dict = convert_state_dict_to_diffusers (
814
820
get_peft_model_state_dict (unwrapped_unet )
815
821
)
@@ -837,7 +843,7 @@ def collate_fn(examples):
837
843
# create pipeline
838
844
pipeline = DiffusionPipeline .from_pretrained (
839
845
args .pretrained_model_name_or_path ,
840
- unet = accelerator . unwrap_model (unet ),
846
+ unet = unwrap_model (unet ),
841
847
revision = args .revision ,
842
848
variant = args .variant ,
843
849
torch_dtype = weight_dtype ,
@@ -878,7 +884,7 @@ def collate_fn(examples):
878
884
if accelerator .is_main_process :
879
885
unet = unet .to (torch .float32 )
880
886
881
- unwrapped_unet = accelerator . unwrap_model (unet )
887
+ unwrapped_unet = unwrap_model (unet )
882
888
unet_lora_state_dict = convert_state_dict_to_diffusers (get_peft_model_state_dict (unwrapped_unet ))
883
889
StableDiffusionPipeline .save_lora_weights (
884
890
save_directory = args .output_dir ,
0 commit comments