Skip to content

[Training] make checkpointing compatible when using torch.compile (part II) #6511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 11, 2024

Conversation

sayakpaul
Copy link
Member

What does this PR do?

Follow-up of: #6483

@sayakpaul sayakpaul changed the title [Training] make checkpointing compatible when using torch.compile. [Training] make checkpointing compatible when using torch.compile (part II) Jan 10, 2024
@@ -1621,16 +1627,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For intermediate places (such as performing validation inference) where we do use accelerator.unwrap_model() -- it's not an issue as the models are directly used. But here, since we're obtaining the state dicts, we need to get the _orig_mod out in case torch.compile() was called. LMK if it's not clear.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@sayakpaul sayakpaul merged commit be0b425 into main Jan 11, 2024
@sayakpaul sayakpaul deleted the torch-compile-compatible-training-ii branch January 11, 2024 13:08
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…part II) (huggingface#6511)

make checkpointing compatible when using torch.compile.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants