diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 27d7bc54432f..fe62c8133c31 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -46,6 +46,7 @@ from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module if is_wandb_available(): @@ -833,6 +834,12 @@ def collate_fn(examples): tracker_config.pop("validation_prompts") accelerator.init_trackers(args.tracker_project_name, tracker_config) + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -912,7 +919,7 @@ def collate_fn(examples): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -927,7 +934,7 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -1023,7 +1030,7 @@ def collate_fn(examples): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters())