
Fix unexpected pipeline dtype change when using SDXL reference community pipelines in float16 mode #10670


Conversation

dimitribarbot (Contributor)

What does this PR do?

When using one of the SDXL reference pipelines with the torch_dtype parameter set to torch.float16 and a VAE whose config has force_upcast enabled, the VAE is upcast to torch.float32 during inference but is never downcast back to torch.float16, which leads to side effects.

Indeed, the dtype of the pipeline is defined in pipelines/pipeline_utils.py as:

@property
def dtype(self) -> torch.dtype:
    r"""
    Returns:
        `torch.dtype`: The torch dtype on which the pipeline is located.
    """
    module_names, _ = self._get_signature_keys(self)
    modules = [getattr(self, n, None) for n in module_names]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    for module in modules:
        return module.dtype

    return torch.float32

When the VAE comes first among the pipeline's torch.nn.Module components, the pipeline dtype is reported as float32 instead of float16.

In my case, it causes the following error when running the unet together with an IP-Adapter:

RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
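
For illustration, here is a minimal sketch of a setup that triggers this. The community pipeline id and the choice of the base SDXL checkpoint (whose stock VAE config has force_upcast enabled) are assumptions for the example, not details taken from this PR:

import torch
from diffusers import DiffusionPipeline

# Assumed community pipeline id for the example; the stock SDXL VAE config
# has force_upcast=True, which triggers the float32 upcast during encoding.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="stable_diffusion_xl_reference",
    torch_dtype=torch.float16,
).to("cuda")

print(pipe.dtype)  # torch.float16 before inference

# After running inference with a reference image, the VAE stays in float32,
# so pipe.dtype may now report torch.float32, depending on module order.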

This PR addresses the issue by downcasting the VAE back to float16 after the reference image latents have been computed.
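
The pattern is roughly the following (a sketch of the idea inside the pipeline's reference-image encoding step; names may differ from the merged change):

# Inside the pipeline, around the reference image encoding step:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
    self.upcast_vae()  # casts the VAE to float32 for numerically stable encoding
    ref_image = ref_image.to(next(self.vae.parameters()).dtype)

ref_image_latents = self.vae.encode(ref_image).latent_dist.sample(generator=generator)
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

if needs_upcasting:
    # The fix: downcast the VAE back so pipe.dtype keeps reporting float16
    self.vae.to(dtype=torch.float16)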

Additionally, I'd like to mention that module_names in the dtype property above does not seem to be retrieved deterministically, even though I instantiate my pipeline with exactly the same code each time. For instance, with the controlnet reference pipeline I sometimes get:

{'tokenizer_2', 'tokenizer', 'text_encoder_2', 'unet', 'controlnet', 'scheduler', 'vae', 'feature_extractor', 'text_encoder', 'image_encoder'}

And sometimes:

{'feature_extractor', 'scheduler', 'vae', 'tokenizer', 'unet', 'text_encoder_2', 'tokenizer_2', 'text_encoder', 'image_encoder', 'controlnet'}

In the first case, the first torch.nn.Module encountered is a text encoder that is still in float16, so the pipeline dtype is unchanged. In the second case, the upcast vae comes first and the pipeline dtype flips to float32. This makes the bug hard to reproduce reliably.
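
If module_names is returned as a plain Python set (which is my reading of _get_signature_keys), this would be explained by string hash randomization: the iteration order of a set of strings can change from one interpreter run to the next. A quick demonstration:

# Run this script twice (with Python's default hash randomization):
# the printed order of the names can differ between runs.
names = {"vae", "unet", "text_encoder", "scheduler", "tokenizer"}
print(list(names))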

Perhaps something could be done in this part of the code to remove this non-determinism?
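
For example (just a sketch of one option, not a proposal tied to this PR), iterating the names in sorted order would at least make the result reproducible across runs:

@property
def dtype(self) -> torch.dtype:
    module_names, _ = self._get_signature_keys(self)
    # Sort the names so the first nn.Module found no longer depends on
    # the (randomized) iteration order of the set.
    modules = [getattr(self, n, None) for n in sorted(module_names)]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    for module in modules:
        return module.dtype

    return torch.float32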

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@asomoza (Member) commented Jan 28, 2025

Hi, thanks. For the determinism problem: if it is in the community pipelines, you're more than welcome to open a new PR to fix it; if you find the same problem in the core, you can open an issue about it. But let's keep this PR simple, with the vae fix only.


asomoza merged commit 196aef5 into huggingface:main on Jan 28, 2025 (9 checks passed).
@dimitribarbot (Contributor, Author)

> Hi, thanks. For the determinism problem: if it is in the community pipelines, you're more than welcome to open a new PR to fix it; if you find the same problem in the core, you can open an issue about it. But let's keep this PR simple, with the vae fix only.

Thank you for the PR review and merge.

I totally agree with keeping this PR as simple as possible. The determinism issue is in the core (DiffusionPipeline); I managed to reproduce it using only a StableDiffusionXLPipeline. I will follow your advice and open a new issue with a reproducible scenario.

dimitribarbot deleted the fix-dtype-issue-in-sdxl-reference-pipelines branch on January 28, 2025 at 14:14.