@@ -49,6 +49,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
@@ -489,6 +490,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
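The helper above peels off two layers of wrapping: `accelerator.prepare` wraps models for distributed training, and `torch.compile` wraps them in an `OptimizedModule` that keeps the original module at `_orig_mod` (which is, roughly, what `is_compiled_module` checks for). A minimal sketch of that second layer, assuming torch >= 2.0 and a toy module standing in for the UNet:

    import torch
    import torch.nn as nn

    toy = nn.Linear(4, 4)          # stand-in for the UNet
    compiled = torch.compile(toy)  # wraps `toy` in an OptimizedModule

    # The original module survives on the wrapper; `unwrap_model` relies
    # on exactly this attribute to recover it before saving.
    assert compiled._orig_mod is toy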
@@ -845,7 +851,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
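The `.sample` to `return_dict=False)[0]` change is behavior-preserving: diffusers models return an output dataclass by default, and `return_dict=False` makes them return a plain tuple instead, which skips the wrapper object and tends to interact better with `torch.compile`. A hedged sketch of the equivalence, using a deliberately tiny made-up UNet config so it runs quickly (the real script loads pretrained weights instead):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=8,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    ).eval()

    x = torch.randn(1, 4, 8, 8)   # noisy latents
    t = torch.tensor([10])        # timestep
    cond = torch.randn(1, 7, 32)  # encoder hidden states

    with torch.no_grad():
        pred_a = unet(x, t, cond).sample                 # dataclass field
        pred_b = unet(x, t, cond, return_dict=False)[0]  # tuple element
    assert torch.allclose(pred_a, pred_b)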
@@ -919,9 +925,9 @@ def collate_fn(examples):
                 # The models need unwrapping for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=accelerator.unwrap_model(text_encoder),
-                    vae=accelerator.unwrap_model(vae),
+                    unet=unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder),
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
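This hunk builds the in-training validation pipeline. For orientation, a hedged sketch of how such a pipeline is then typically invoked, reusing the script's existing `pipeline`, `accelerator`, and `args` names (`--validation_prompt` and `--val_image_url` are assumed to be this example script's CLI options, and the guidance values are illustrative defaults):

    from diffusers.utils import load_image

    pipeline = pipeline.to(accelerator.device)
    original_image = load_image(args.val_image_url)

    # InstructPix2Pix conditions on both text and the input image, so it
    # exposes separate text and image guidance scales.
    edited = pipeline(
        args.validation_prompt,
        image=original_image,
        num_inference_steps=20,
        image_guidance_scale=1.5,
        guidance_scale=7.0,
    ).images[0]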
@@ -965,14 +971,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            vae=accelerator.unwrap_model(vae),
+            text_encoder=unwrap_model(text_encoder),
+            vae=unwrap_model(vae),
             unet=unet,
             revision=args.revision,
             variant=args.variant,
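The `ema_unet.copy_to(...)` call is why the UNet is unwrapped first: the EMA shadow weights are copied into the bare module before export, not into the DDP or compile wrapper. A minimal sketch of the `EMAModel` lifecycle this relies on, with a toy module standing in for the UNet:

    import torch.nn as nn
    from diffusers.training_utils import EMAModel

    model = nn.Linear(4, 4)  # stand-in for the UNet
    ema = EMAModel(model.parameters(), decay=0.9999)

    # During training: refresh the shadow weights after each optimizer step.
    ema.step(model.parameters())

    # At save time: overwrite the live weights with the EMA average,
    # mirroring `ema_unet.copy_to(unet.parameters())` in the diff.
    ema.copy_to(model.parameters())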