Skip to content

Commit 33d2b5b

Browse files
authored
SD text-to-image torch compile compatible (#6519)
* added unwrapper * fix typo
1 parent f486d34 commit 33d2b5b

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
from diffusers.training_utils import EMAModel, compute_snr
4747
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
4848
from diffusers.utils.import_utils import is_xformers_available
49+
from diffusers.utils.torch_utils import is_compiled_module
4950

5051

5152
if is_wandb_available():
@@ -833,6 +834,12 @@ def collate_fn(examples):
833834
tracker_config.pop("validation_prompts")
834835
accelerator.init_trackers(args.tracker_project_name, tracker_config)
835836

837+
# Helper to recover the raw model: strips the accelerate wrapper and,
# when the model was compiled with `torch.compile`, the compile wrapper too.
def unwrap_model(model):
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # `torch.compile` wraps the module; the original lives in `_orig_mod`.
        unwrapped = unwrapped._orig_mod
    return unwrapped
836843
# Train!
837844
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
838845

@@ -912,7 +919,7 @@ def collate_fn(examples):
912919
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
913920

914921
# Get the text embedding for conditioning
915-
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
922+
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
916923

917924
# Get the target for loss depending on the prediction type
918925
if args.prediction_type is not None:
@@ -927,7 +934,7 @@ def collate_fn(examples):
927934
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
928935

929936
# Predict the noise residual and compute loss
930-
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
937+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
931938

932939
if args.snr_gamma is None:
933940
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -1023,7 +1030,7 @@ def collate_fn(examples):
10231030
# Create the pipeline using the trained modules and save it.
10241031
accelerator.wait_for_everyone()
10251032
if accelerator.is_main_process:
1026-
unet = accelerator.unwrap_model(unet)
1033+
unet = unwrap_model(unet)
10271034
if args.use_ema:
10281035
ema_unet.copy_to(unet.parameters())
10291036

0 commit comments

Comments
 (0)