
Commit 7ce89e9

Make text-to-image SD LoRA Training Script torch.compile compatible (#6555)

make compile compatible

1 parent 05faf32 · commit 7ce89e9

File tree: 1 file changed (+11, -5)


examples/text_to_image/train_text_to_image_lora.py

Lines changed: 11 additions & 5 deletions
@@ -46,6 +46,7 @@
 from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -596,6 +597,11 @@ def tokenize_captions(examples, is_train=True):
         ]
     )

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
@@ -729,7 +735,7 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
@@ -744,7 +750,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

                 if args.snr_gamma is None:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
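
The two return_dict=False changes make the text encoder and the UNet return plain tuples instead of output dataclasses, so the prediction is taken with [0] rather than the .sample attribute; tuple outputs generally trace more cleanly under torch.compile. The self-contained check below is only a sketch: the UNet2DConditionModel config values are made up to keep the model tiny and are not taken from the training script.

# Hypothetical tiny UNet showing that the tuple-style call returns the same tensor
# that the `.sample` attribute used to; config values are arbitrary.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 32),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
latents = torch.randn(1, 4, 8, 8)
timesteps = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 4, 32)

with torch.no_grad():
    as_dataclass = unet(latents, timesteps, encoder_hidden_states).sample
    as_tuple = unet(latents, timesteps, encoder_hidden_states, return_dict=False)[0]
assert torch.allclose(as_dataclass, as_tuple)  # same prediction, different return type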
@@ -809,7 +815,7 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)

-                        unwrapped_unet = accelerator.unwrap_model(unet)
+                        unwrapped_unet = unwrap_model(unet)
                         unet_lora_state_dict = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(unwrapped_unet)
                         )
@@ -837,7 +843,7 @@ def collate_fn(examples):
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -878,7 +884,7 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)

-        unwrapped_unet = accelerator.unwrap_model(unet)
+        unwrapped_unet = unwrap_model(unet)
        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
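
Note that nothing in this diff applies torch.compile itself; one way to exercise the compiled path is to let accelerate enable a TorchDynamo backend at launch time. The command below is only an illustrative guess: the model, dataset, and output paths are placeholders taken from the example's usual defaults, and --dynamo_backend assumes an accelerate version with TorchDynamo support.

# Hypothetical launch command; paths, dataset, and hyperparameters are placeholders.
accelerate launch --dynamo_backend=inductor train_text_to_image_lora.py \
  --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
  --dataset_name=lambdalabs/pokemon-blip-captions \
  --output_dir=sd-model-lora \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=100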
