Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode #10670
What does this PR do?
When using one of the SDXL reference pipelines with the `torch_dtype` parameter set to `torch.float16` and a VAE that has the `force_upcast` configuration, the VAE is upcast to `torch.float32` during inference. However, it is never downcast back to `torch.float16`, which leads to side effects.

Indeed, the dtype of the pipeline is defined by the `dtype` property in `pipelines/pipeline_utils.py`. When the VAE comes first, the entire pipeline dtype is no longer set to float16 but to float32.
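The mechanism described above can be sketched with toy stand-ins. This is only an illustration under stated assumptions: `ToyModule`, `ToyPipeline`, and `encode_reference` are hypothetical names, dtypes are plain strings, and nothing here is the actual diffusers or torch code. It assumes the pipeline reports the dtype of the first module it encounters.

```python
# Toy stand-ins only: these classes are hypothetical, not diffusers or torch APIs.
class ToyModule:
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # Mimic an in-place dtype cast that returns the module.
        self.dtype = dtype
        return self


class ToyPipeline:
    def __init__(self, modules):
        # modules: dict of component name -> ToyModule, in lookup order
        self.modules = modules

    @property
    def dtype(self):
        # Assumption: the first module encountered decides the pipeline dtype.
        for module in self.modules.values():
            return module.dtype
        return "float32"


def encode_reference(vae, restore_dtype):
    original_dtype = vae.dtype
    vae.to("float32")  # upcast, as a VAE with force_upcast would need
    # ... the reference image latents would be computed here ...
    if restore_dtype:
        vae.to(original_dtype)  # cast back so the pipeline dtype is unchanged


def build_pipeline():
    # The VAE is deliberately listed first to reproduce the problematic case.
    return ToyPipeline({"vae": ToyModule("float16"), "unet": ToyModule("float16")})


pipe = build_pipeline()
encode_reference(pipe.modules["vae"], restore_dtype=False)
dtype_without_downcast = pipe.dtype

pipe = build_pipeline()
encode_reference(pipe.modules["vae"], restore_dtype=True)
dtype_with_downcast = pipe.dtype

print(dtype_without_downcast, dtype_with_downcast)
```

In this toy model, leaving the VAE upcast flips the reported pipeline dtype, while restoring the original dtype after encoding keeps it stable.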
In my case, it causes an error when running inference with the UNet while using an IP adapter at the same time.
This PR aims to address this issue by downcasting the VAE to `float16` after computing the reference image latents.

Additionally, I'd like to mention that `module_names` in the above `dtype` property does not seem to be retrieved deterministically, despite using the same code to instantiate my pipeline. For instance, with the controlnet reference pipeline, the module names come back in one order on some runs and in a different order on others. In the first case, the pipeline `dtype` is not changed to `float32`. In the second case it is.

This makes the bug hard to reproduce. Perhaps something could be done in this part of the code to remove this non-determinism?
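One possible direction, sketched below, is to iterate over the module names in a fixed order. This rests on an assumption the PR does not confirm: that the names are gathered in an unordered collection such as a Python `set`, whose iteration order for strings can change between interpreter runs because of hash randomization. The function `pipeline_dtype` and the `components` mapping are hypothetical, and dtypes are plain strings.

```python
def pipeline_dtype(components, module_names):
    # Sort the names so the result does not depend on how `module_names`
    # was built (for example, from a set with run-dependent ordering).
    for name in sorted(module_names):
        module_dtype = components.get(name)
        if module_dtype is not None:
            return module_dtype
    return "float32"


# Hypothetical state after the VAE has been upcast and not cast back.
components = {"unet": "float16", "vae": "float32", "controlnet": "float16"}

# The same names supplied as a set and as a differently ordered list.
result_from_set = pipeline_dtype(components, {"vae", "unet", "controlnet"})
result_from_list = pipeline_dtype(components, ["unet", "controlnet", "vae"])

print(result_from_set, result_from_list)
```

Sorting alphabetically is an arbitrary choice made for the sketch. Any fixed rule, such as always preferring one specific component, would make the reported dtype reproducible across runs.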
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.