Skip to content

[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile #6483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids = text_input_ids_list[i]

prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to add output_hidden_states=True here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was already there. Look again, please.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True here too:

prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry yeah indeed we need it! Thanks for double-checking

)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds[-1][-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
Expand Down Expand Up @@ -1429,7 +1428,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
return_dict=False,
)[0]
else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
Expand All @@ -1443,8 +1443,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
noisy_model_input,
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
return_dict=False,
)[0]

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
Expand Down