-
Notifications
You must be signed in to change notification settings - Fork 6k
[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile #6483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): | |||
text_input_ids = text_input_ids_list[i] | |||
|
|||
prompt_embeds = text_encoder( | |||
text_input_ids.to(text_encoder.device), | |||
output_hidden_states=True, | |||
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to add output_hidden_states=True
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was already there. Look again, please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True here too:
diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Line 415 in aa1797e
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah sorry yeah indeed we need it! Thanks for double-checking
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good idea to add return_dict=False
everywhere, I don't fully understand why we set output_hidden_states=True
here
@patrickvonplaten mergeable now? |
…torch.compile (huggingface#6483) * make it torch.compile comaptible * make the text encoder compatible too. * style
What does this PR do?
Makes the DreamBooth SDXL LoRA compatible with
torch.compile()
. This has been requested by the community for a while.My
accelerate
config:I think this PR can definitely serve as a reference for contributors willing to do this for the other scripts (I will create an issue for that later).