
Commit b053053

Make InstructPix2Pix Training Script torch.compile compatible (#6558)
* added torch.compile for pix2pix
* required changes
1 parent 08702fc · commit b053053


examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 13 additions & 7 deletions
@@ -49,6 +49,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -489,6 +490,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
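
For context, `torch.compile` does not modify a module in place; in PyTorch 2.x it returns an `OptimizedModule` wrapper that keeps the original module at `_orig_mod`, which is what `is_compiled_module` checks for. A minimal sketch of the behavior the helper unwraps (illustrative, not part of this commit):

# Sketch: what torch.compile wrapping looks like in PyTorch 2.x.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)             # OptimizedModule
print(compiled._orig_mod is model)         # True: original module is kept
print(next(iter(compiled.state_dict())))   # '_orig_mod.weight': prefixed keys

Without unwrapping, those `_orig_mod.`-prefixed parameter names would leak into saved checkpoints and into the pipeline components built below.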
@@ -845,7 +851,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
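
The `return_dict=False` change sidesteps the output dataclass, whose attribute access can introduce graph breaks when the UNet is wrapped by `torch.compile`; the first tuple element is the same `sample` tensor. Side by side (a sketch reusing the names from the diff above, assuming the inputs are already prepared):

# Dataclass attribute access (previous behavior):
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
# Tuple indexing (compile-friendly, same tensor):
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]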
@@ -919,9 +925,9 @@ def collate_fn(examples):
                 # The models need unwrapping because for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=accelerator.unwrap_model(text_encoder),
-                    vae=accelerator.unwrap_model(vae),
+                    unet=unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder),
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -965,14 +971,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            vae=accelerator.unwrap_model(vae),
+            text_encoder=unwrap_model(text_encoder),
+            vae=unwrap_model(vae),
             unet=unet,
             revision=args.revision,
             variant=args.variant,
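
With the helper in place, a compiled UNet survives checkpointing and pipeline export whether it is compiled directly or through accelerate's dynamo integration. A hypothetical direct usage, where everything outside the diff is an assumption rather than part of the script:

# Hypothetical: compile the UNet before handing it to accelerate; the
# unwrap_model() helper above then strips both the accelerate wrapper and
# the OptimizedModule wrapper when the pipelines are assembled.
unet = torch.compile(unet)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)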
