[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile #6483

Merged

4 commits merged into main on Jan 9, 2024

Conversation

sayakpaul (Member):

What does this PR do?

Makes the DreamBooth SDXL LoRA training script compatible with torch.compile(). This has been requested by the community for a while.

My accelerate config:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
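
The dynamo_config entry is what turns compilation on: when dynamo_backend is set to INDUCTOR, Accelerate wraps models passed through accelerator.prepare() in torch.compile with the Inductor backend. A minimal sketch of the rough equivalent (the nn.Linear stand-in is purely illustrative, not the script's UNet):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the trained model; the real script prepares
# the SDXL UNet (with LoRA adapters) via accelerator.prepare().
model = nn.Linear(8, 8)

# Roughly what Accelerate applies under the hood when the config above
# sets dynamo_backend: INDUCTOR.
model = torch.compile(model, backend="inductor")
out = model(torch.randn(2, 8))  # the first call triggers compilation
```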

I think this PR can serve as a reference for contributors who want to do the same for the other training scripts (I will create an issue for that later).
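
One recurring change this kind of compatibility work involves, sketched below as a minimal example rather than the PR's exact code: torch.compile wraps a module in an OptimizedModule that keeps the original module under _orig_mod, so any code that reads attributes off the model (or saves its state dict) has to strip that wrapper in addition to the usual Accelerate wrapper.

```python
import torch

def unwrap_model(accelerator, model):
    # Strip the Accelerate wrapper (e.g., DistributedDataParallel) first.
    model = accelerator.unwrap_model(model)
    # torch.compile wraps modules in an OptimizedModule that keeps the
    # original module under `_orig_mod`; strip that layer as well.
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        model = model._orig_mod
    return model
```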

```diff
@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]

         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
```
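
For context on why the call gains return_dict=False: transformers models normally return a ModelOutput dataclass, and constructing that dict-like object can introduce graph breaks under TorchDynamo, whereas a plain tuple traces cleanly. A sketch of how the tuple is then consumed inside encode_prompt (text_encoder and text_input_ids are assumed in scope; the exact indices follow the SDXL convention and are an assumption here, not part of this diff):

```python
prompt_embeds = text_encoder(
    text_input_ids.to(text_encoder.device),
    output_hidden_states=True,
    return_dict=False,
)
# With return_dict=False the encoder returns a plain tuple: the pooled
# output first and the per-layer hidden states last. SDXL uses the
# penultimate hidden state as the prompt embedding (assumed indexing).
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds[-1][-2]
```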
Contributor:

Why do we need to add output_hidden_states=True here?

sayakpaul (Member, Author):

It was already there. Look again, please.

sayakpaul (Member, Author):

True here too:

prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

Contributor:

Ah sorry yeah indeed we need it! Thanks for double-checking
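(For context: SDXL conditions on the penultimate hidden state of its text encoders, plus the pooled output of the second encoder, rather than on the final layer output, so the hidden states have to be requested explicitly.)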

patrickvonplaten (Contributor) left a comment:

Very good idea to add return_dict=False everywhere. I don't fully understand why we set output_hidden_states=True here, though.

sayakpaul (Member, Author):

@patrickvonplaten mergeable now?

sayakpaul merged commit 4497b3e into main on Jan 9, 2024.
sayakpaul deleted the training/torch-compile branch on January 9, 2024 at 14:42.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024:
…torch.compile (huggingface#6483)

* make it torch.compile compatible

* make the text encoder compatible too.

* style