Skip to content

Make text-to-image SDXL LoRA Training Script torch.compile compatible #6556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
Expand Down Expand Up @@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids = text_input_ids_list[i]

prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds[-1][-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
Expand Down Expand Up @@ -640,6 +640,11 @@ def main(args):
if param.requires_grad:
param.data = param.to(torch.float32)

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
Expand All @@ -650,13 +655,13 @@ def save_model_hook(models, weights, output_dir):
text_encoder_two_lora_layers_to_save = None

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
Expand All @@ -681,11 +686,11 @@ def load_model_hook(models, input_dir):
while len(models) > 0:
model = models.pop()

if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
Expand Down Expand Up @@ -1034,8 +1039,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
noisy_model_input,
timesteps,
prompt_embeds,
added_cond_kwargs=unet_added_conditions,
return_dict=False,
)[0]

# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
Expand Down Expand Up @@ -1128,9 +1137,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
Expand Down Expand Up @@ -1169,12 +1178,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_two = unwrap_model(text_encoder_two)

text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
Expand Down