
[pipelines] allow models to run with a user-provided dtype map instead of a single dtype #10108

Closed
@sayakpaul

Description

Newer models like Mochi-1 run the text encoder and VAE decoding in FP32 while keeping the denoising process under torch.bfloat16 autocast.

Currently, our pipelines cannot run the different models involved in different precisions, as we set a single global torch_dtype when initializing the pipeline.
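
The closest workaround today is to load everything in one dtype and then cast the sensitive components by hand. A minimal sketch (the Mochi checkpoint name on the Hub is assumed):

```python
import torch
from diffusers import MochiPipeline

# Load the whole pipeline in bfloat16, then manually upcast the components
# that are sensitive to reduced precision.
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.text_encoder.to(torch.float32)
pipe.vae.to(torch.float32)
```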

We have some pipelines like SDXL where the VAE has a config attribute called force_upcast and it's handled within the pipeline implementation like so:

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)

Another way to achieve this could be to decouple the major computation stages of the pipeline so that users can choose whatever supported torch_dtype they want for each stage. Here is an example.
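
A minimal sketch of that decoupling with SDXL (assumptions: we stop at the latents with output_type="latent" and decode manually in float32):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Stage 1: run text encoding + denoising in bfloat16 and stop at the latents.
latents = pipe(prompt="an astronaut riding a horse", output_type="latent").images

# Stage 2: decode in float32, where the SDXL VAE is numerically stable.
pipe.vae.to(torch.float32)
with torch.no_grad():
    image = pipe.vae.decode(
        latents.to(torch.float32) / pipe.vae.config.scaling_factor
    ).sample
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
```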

But this is an involved process and is a power-user thing, IMO. What if we allowed users to pass a torch_dtype map like so:

{"unet": torch.bfloat16, "vae": torch.float32, "text_encoder": torch.float32}

This, along with @a-r-r-o-w's idea of an upcast marker, could really benefit the pipelines that are not resilient to precision changes.

Cc: @DN6 @yiyixuxu @hlky
