diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 27dedc8f7fd1..9aa66499cf5e 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -46,6 +46,7 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -598,6 +599,11 @@ def tokenize_captions(examples, is_train=True):
         ]
     )
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
@@ -731,7 +737,7 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
 
                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
@@ -746,7 +752,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
 
                 if args.snr_gamma is None:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -811,7 +817,7 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = accelerator.unwrap_model(unet)
+                        unwrapped_unet = unwrap_model(unet)
                         unet_lora_state_dict = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(unwrapped_unet)
                         )
@@ -839,7 +845,7 @@ def collate_fn(examples):
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -880,7 +886,7 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
 
-        unwrapped_unet = accelerator.unwrap_model(unet)
+        unwrapped_unet = unwrap_model(unet)
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
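For context on the two kinds of change above: `torch.compile` wraps an `nn.Module` in a `torch._dynamo.eval_frame.OptimizedModule` that keeps the original module on its `_orig_mod` attribute, so `accelerator.unwrap_model` alone no longer yields the raw UNet whose LoRA state dict should be saved (the compiled wrapper's state-dict keys gain an `_orig_mod.` prefix). The new `unwrap_model` helper strips that wrapper via `is_compiled_module`. The `return_dict=False` calls make the text encoder and UNet return plain tuples instead of output dataclasses, which keeps their forward passes friendlier to `torch.compile`. Below is a minimal standalone sketch of the unwrapping pattern, using a hypothetical `unwrap_compiled` helper in place of the script's closure (no Accelerate or diffusers dependency):

```python
import torch
import torch.nn as nn


def unwrap_compiled(model: nn.Module) -> nn.Module:
    """Return the original module if `model` was wrapped by torch.compile.

    torch.compile returns an OptimizedModule that stores the underlying
    module on `_orig_mod`; reading the state dict from the original module
    avoids keys prefixed with `_orig_mod.`.
    """
    return getattr(model, "_orig_mod", model)


if __name__ == "__main__":
    net = nn.Linear(4, 4)
    compiled = torch.compile(net)
    assert unwrap_compiled(compiled) is net  # compiled module is unwrapped
    assert unwrap_compiled(net) is net       # plain module passes through
```

In the script itself, the helper additionally calls `accelerator.unwrap_model` first, so any distributed wrapper (e.g. DDP) is removed before the compile wrapper is checked.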