Skip to content

Commit 08702fc

Browse files
authored
Make text-to-image SDXL LoRA Training Script torch.compile compatible (#6556)
make compile compatible
1 parent 7ce89e9 commit 08702fc

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 26 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -54,6 +54,7 @@
5454
from diffusers.training_utils import cast_training_params, compute_snr
5555
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
5656
from diffusers.utils.import_utils import is_xformers_available
57+
from diffusers.utils.torch_utils import is_compiled_module
5758

5859

5960
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
460461
text_input_ids = text_input_ids_list[i]
461462

462463
prompt_embeds = text_encoder(
463-
text_input_ids.to(text_encoder.device),
464-
output_hidden_states=True,
464+
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
465465
)
466466

467467
# We are only ALWAYS interested in the pooled output of the final text encoder
468468
pooled_prompt_embeds = prompt_embeds[0]
469-
prompt_embeds = prompt_embeds.hidden_states[-2]
469+
prompt_embeds = prompt_embeds[-1][-2]
470470
bs_embed, seq_len, _ = prompt_embeds.shape
471471
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
472472
prompt_embeds_list.append(prompt_embeds)
@@ -637,6 +637,11 @@ def main(args):
637637
# only upcast trainable parameters (LoRA) into fp32
638638
cast_training_params(models, dtype=torch.float32)
639639

640+
def unwrap_model(model):
641+
model = accelerator.unwrap_model(model)
642+
model = model._orig_mod if is_compiled_module(model) else model
643+
return model
644+
640645
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
641646
def save_model_hook(models, weights, output_dir):
642647
if accelerator.is_main_process:
@@ -647,13 +652,13 @@ def save_model_hook(models, weights, output_dir):
647652
text_encoder_two_lora_layers_to_save = None
648653

649654
for model in models:
650-
if isinstance(model, type(accelerator.unwrap_model(unet))):
655+
if isinstance(model, type(unwrap_model(unet))):
651656
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
652-
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
657+
elif isinstance(model, type(unwrap_model(text_encoder_one))):
653658
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
654659
get_peft_model_state_dict(model)
655660
)
656-
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
661+
elif isinstance(model, type(unwrap_model(text_encoder_two))):
657662
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
658663
get_peft_model_state_dict(model)
659664
)
@@ -678,11 +683,11 @@ def load_model_hook(models, input_dir):
678683
while len(models) > 0:
679684
model = models.pop()
680685

681-
if isinstance(model, type(accelerator.unwrap_model(unet))):
686+
if isinstance(model, type(unwrap_model(unet))):
682687
unet_ = model
683-
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
688+
elif isinstance(model, type(unwrap_model(text_encoder_one))):
684689
text_encoder_one_ = model
685-
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
690+
elif isinstance(model, type(unwrap_model(text_encoder_two))):
686691
text_encoder_two_ = model
687692
else:
688693
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1031,8 +1036,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
10311036
)
10321037
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
10331038
model_pred = unet(
1034-
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
1035-
).sample
1039+
noisy_model_input,
1040+
timesteps,
1041+
prompt_embeds,
1042+
added_cond_kwargs=unet_added_conditions,
1043+
return_dict=False,
1044+
)[0]
10361045

10371046
# Get the target for loss depending on the prediction type
10381047
if args.prediction_type is not None:
@@ -1125,9 +1134,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
11251134
pipeline = StableDiffusionXLPipeline.from_pretrained(
11261135
args.pretrained_model_name_or_path,
11271136
vae=vae,
1128-
text_encoder=accelerator.unwrap_model(text_encoder_one),
1129-
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1130-
unet=accelerator.unwrap_model(unet),
1137+
text_encoder=unwrap_model(text_encoder_one),
1138+
text_encoder_2=unwrap_model(text_encoder_two),
1139+
unet=unwrap_model(unet),
11311140
revision=args.revision,
11321141
variant=args.variant,
11331142
torch_dtype=weight_dtype,
@@ -1166,12 +1175,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
11661175
# Save the lora layers
11671176
accelerator.wait_for_everyone()
11681177
if accelerator.is_main_process:
1169-
unet = accelerator.unwrap_model(unet)
1178+
unet = unwrap_model(unet)
11701179
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
11711180

11721181
if args.train_text_encoder:
1173-
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1174-
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1182+
text_encoder_one = unwrap_model(text_encoder_one)
1183+
text_encoder_two = unwrap_model(text_encoder_two)
11751184

11761185
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
11771186
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))

0 commit comments

Comments (0)