|
52 | 52 | from diffusers.training_utils import EMAModel
|
53 | 53 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
|
54 | 54 | from diffusers.utils.import_utils import is_xformers_available
|
| 55 | +from diffusers.utils.torch_utils import is_compiled_module
55 | 56 |
|
56 | 57 |
|
57 | 58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
|
@@ -531,6 +532,11 @@ def main():
|
531 | 532 | else:
|
532 | 533 | raise ValueError("xformers is not available. Make sure it is installed correctly")
|
533 | 534 |
|
| 535 | + def unwrap_model(model):
| 536 | + model = accelerator.unwrap_model(model)
| 537 | + model = model._orig_mod if is_compiled_module(model) else model
| 538 | + return model
| 539 | +
534 | 540 | # `accelerate` 0.16.0 will have better support for customized saving
|
535 | 541 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
536 | 542 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
@@ -1044,8 +1050,12 @@ def collate_fn(examples):
|
1044 | 1050 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1045 | 1051 |
|
1046 | 1052 | model_pred = unet(
|
1047 | | - concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1048 | | - ).sample
| 1053 | + concatenated_noisy_latents,
| 1054 | + timesteps,
| 1055 | + encoder_hidden_states,
| 1056 | + added_cond_kwargs=added_cond_kwargs,
| 1057 | + return_dict=False,
| 1058 | + )[0]
1049 | 1059 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
1050 | 1060 |
|
1051 | 1061 | # Gather the losses across all processes for logging (if we use distributed training).
|
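The hunk above replaces the `.sample` attribute access with `return_dict=False`: the UNet then returns a plain tuple rather than a `UNet2DConditionOutput`, and index `[0]` is the same predicted-noise tensor, which tends to play better with `torch.compile`. A small self-contained illustration with a tiny, randomly initialised UNet (the config values below are illustrative, not the SDXL ones):

```python
import torch

from diffusers import UNet2DConditionModel

# Tiny UNet purely to demonstrate the two output styles.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
)
unet.eval()

sample = torch.randn(1, 4, 8, 8)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 4, 32)

with torch.no_grad():
    # Default: a UNet2DConditionOutput dataclass with a `.sample` field.
    pred_from_dataclass = unet(sample, timestep, encoder_hidden_states).sample
    # return_dict=False: a plain tuple; element 0 holds the same tensor.
    pred_from_tuple = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]

torch.testing.assert_close(pred_from_dataclass, pred_from_tuple)
```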
@@ -1115,7 +1125,7 @@ def collate_fn(examples):
|
1115 | 1125 | # The models need unwrapping for compatibility in distributed training mode.
|
1116 | 1126 | pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
|
1117 | 1127 | args.pretrained_model_name_or_path,
|
1118 | | - unet=accelerator.unwrap_model(unet),
| 1128 | + unet=unwrap_model(unet),
1119 | 1129 | text_encoder=text_encoder_1,
|
1120 | 1130 | text_encoder_2=text_encoder_2,
|
1121 | 1131 | tokenizer=tokenizer_1,
|
@@ -1177,7 +1187,7 @@ def collate_fn(examples):
|
1177 | 1187 | # Create the pipeline using the trained modules and save it.
|
1178 | 1188 | accelerator.wait_for_everyone()
|
1179 | 1189 | if accelerator.is_main_process:
|
1180 | | - unet = accelerator.unwrap_model(unet)
| 1190 | + unet = unwrap_model(unet)
1181 | 1191 | if args.use_ema:
|
1182 | 1192 | ema_unet.copy_to(unet.parameters())
|
1183 | 1193 |
|
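Finally, the EMA branch copies the exponentially averaged weights into the unwrapped UNet just before the pipeline is assembled and saved. A toy sketch of that `EMAModel` flow from `diffusers.training_utils`, again with a plain `nn.Linear` standing in for the UNet and the actual optimisation step elided:

```python
import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)                  # stand-in for the trained UNet
ema_model = EMAModel(model.parameters(), decay=0.9999)

for _ in range(3):
    # ... forward pass, loss, optimizer.step() would update `model` here ...
    ema_model.step(model.parameters())         # refresh the shadow (EMA) copy of the weights

# At save time: overwrite the live weights with the smoother EMA weights,
# matching `ema_unet.copy_to(unet.parameters())` in the script.
ema_model.copy_to(model.parameters())
```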
|