
Commit e3103e1

Make InstructPix2Pix SDXL Training Script torch.compile compatible (#6576)
* changes for pix2pix_sdxl
* style fix
1 parent b053053 commit e3103e1
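
The script-side change is a small `unwrap_model` helper: when the UNet has been wrapped by `torch.compile` (typically via Accelerate's dynamo backend setting rather than an explicit call in the script), saving it or handing it to a pipeline should use the underlying module. A minimal sketch of that pattern, with a hypothetical toy module standing in for the UNet:

import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet; not part of the training script.
class TinyNet(nn.Module):
    def forward(self, x):
        return x * 2

model = TinyNet()
compiled = torch.compile(model)  # wraps the module in an OptimizedModule

# torch.compile keeps the original nn.Module on `_orig_mod`; this is what
# the new `unwrap_model` helper in the diff reaches for.
assert compiled._orig_mod is model

# Save (or build a pipeline from) the plain module, not the wrapper.
plain = compiled._orig_mod if hasattr(compiled, "_orig_mod") else compiled
torch.save(plain.state_dict(), "tiny_net.pt")

In the script, diffusers' `is_compiled_module` performs the same check, and `accelerator.unwrap_model` is called first so that any distributed (DDP) wrapper is stripped before the compile wrapper.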

File tree

1 file changed: +14 −4 lines changed

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 14 additions & 4 deletions
@@ -52,6 +52,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -531,6 +532,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1044,8 +1050,12 @@ def collate_fn(examples):
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                model_pred = unet(
-                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-                ).sample
+                    concatenated_noisy_latents,
+                    timesteps,
+                    encoder_hidden_states,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
@@ -1115,7 +1125,7 @@ def collate_fn(examples):
                # The models need unwrapping because for compatibility in distributed training mode.
                pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
+                    unet=unwrap_model(unet),
                    text_encoder=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    tokenizer=tokenizer_1,
@@ -1177,7 +1187,7 @@ def collate_fn(examples):
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())
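
The other change in the training loop replaces the `.sample` attribute access with `return_dict=False` plus tuple indexing, presumably to keep the compiled forward returning a plain tuple instead of a ModelOutput object. The two call styles are equivalent in value; a rough sketch with toy classes (ToyOutput and ToyUNet are illustrative, not diffusers classes):

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ToyOutput:
    sample: torch.Tensor

class ToyUNet(nn.Module):
    def forward(self, x, return_dict=True):
        out = x + 1
        # Mirror the diffusers convention: an output object by default,
        # a plain tuple when return_dict=False.
        return ToyOutput(sample=out) if return_dict else (out,)

net = ToyUNet()
x = torch.randn(2, 4)

pred_attr = net(x).sample                  # pre-change style
pred_tuple = net(x, return_dict=False)[0]  # style used in the diff
assert torch.equal(pred_attr, pred_tuple)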
