Description
Newer models like Mochi-1 run the text encoder and VAE decoding in FP32 while keeping the denoising process under torch.bfloat16 autocast.
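As a rough illustration of that precision split, here is a toy sketch. It assumes tiny `nn.Linear` stand-ins (hypothetical) for the text encoder, denoiser and VAE decoder, a made-up update rule in place of a real scheduler, and CPU autocast. It is not Mochi-1's actual implementation.

```python
import torch
from torch import nn

# Hypothetical stand-ins; real pipelines use far larger networks.
torch.manual_seed(0)
text_encoder = nn.Linear(8, 16)  # weights and compute in FP32
denoiser = nn.Linear(16, 16)     # FP32 weights, matmuls run in bf16 under autocast
vae_decoder = nn.Linear(16, 3)   # weights and compute in FP32


def generate(prompt_features: torch.Tensor, num_steps: int = 4) -> torch.Tensor:
    # 1. Text encoding in full precision, outside any autocast region.
    cond = text_encoder(prompt_features)
    latents = torch.randn_like(cond)

    # 2. Denoising loop: autocast runs the Linear matmul in bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        for _ in range(num_steps):
            noise_pred = denoiser(latents + cond)  # bf16 output
            # FP32 - BF16 promotes to FP32, so latents stay in FP32.
            latents = latents - 0.1 * noise_pred

    # 3. Decode outside autocast, making sure the input is FP32.
    return vae_decoder(latents.float())


image = generate(torch.randn(2, 8))
print(image.dtype, image.shape)
```

Keeping the weights in FP32 and only entering autocast around the denoiser means no module has to be permanently cast to a lower precision.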
Currently, our pipelines cannot run the different models involved in different precisions, because we set a single global torch_dtype when initializing the pipeline.
Some pipelines, like SDXL, have a VAE with a config attribute called force_upcast, which is handled within the pipeline implementation like so:
diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py, lines 1264 to 1275 (commit cfdeebd)
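The referenced lines are not reproduced here. As an approximation of the general pattern only (not the actual SDXL code), the sketch below upcasts an FP16 VAE to FP32 for decoding when the flag is set, then casts it back. `TinyVAE`, its `decode` method and the plain `force_upcast` attribute (standing in for `vae.config.force_upcast`) are hypothetical.

```python
import torch
from torch import nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in for a VAE whose decoder is unstable in float16."""

    def __init__(self, force_upcast: bool = True):
        super().__init__()
        self.decoder = nn.Linear(4, 3)
        self.force_upcast = force_upcast  # mirrors the `force_upcast` config flag

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        return self.decoder(latents)


def decode_latents(vae: TinyVAE, latents: torch.Tensor) -> torch.Tensor:
    # Upcast only when the VAE is in FP16 and the flag asks for it.
    needs_upcasting = vae.dtype == torch.float16 and vae.force_upcast
    if needs_upcasting:
        vae.to(torch.float32)
        latents = latents.to(torch.float32)
    image = vae.decode(latents)
    if needs_upcasting:
        # Restore FP16 so the rest of the pipeline sees the dtype it was loaded with.
        vae.to(torch.float16)
    return image


vae = TinyVAE().to(torch.float16)
image = decode_latents(vae, torch.randn(2, 4, dtype=torch.float16))
print(image.dtype, vae.dtype)
```

The limitation this issue points at is visible here: the upcast is a per-pipeline special case tied to one component, rather than something users can configure.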
Another way to achieve this would be to decouple the major computation stages of the pipeline so that users can run each stage in whichever supported torch_dtype they want. Here is an example.
But this is an involved process and more of a power-user thing, IMO. What if we allowed users to pass a torch_dtype map like so:
{"unet": torch.bfloat16, "vae": torch.float32, "text_encoder": torch.float32}
This, along with @a-r-r-o-w's idea of an upcast marker, could really benefit the pipelines that are not resilient to precision changes.
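A minimal sketch of how such a map could be applied when loading components, assuming a hypothetical helper `apply_dtype_map` (not an existing diffusers API) and `nn.Linear` stand-ins for the real models. It accepts either a single dtype (today's global behaviour) or a per-component dict, and falls back to an assumed `default` for components the dict does not mention.

```python
from typing import Dict, Union

import torch
from torch import nn

DtypeSpec = Union[torch.dtype, Dict[str, torch.dtype]]


def apply_dtype_map(
    components: Dict[str, nn.Module],
    torch_dtype: DtypeSpec,
    default: torch.dtype = torch.float32,  # assumed fallback for unlisted components
) -> Dict[str, nn.Module]:
    """Hypothetical helper: cast each component to its requested dtype."""
    for name, module in components.items():
        if isinstance(torch_dtype, dict):
            dtype = torch_dtype.get(name, default)
        else:
            dtype = torch_dtype  # single global dtype, as pipelines do today
        module.to(dtype)  # nn.Module.to casts parameters in place
    return components


# Usage with the map proposed above, on stand-in modules.
components = {
    "unet": nn.Linear(4, 4),
    "vae": nn.Linear(4, 4),
    "text_encoder": nn.Linear(4, 4),
}
apply_dtype_map(
    components,
    {"unet": torch.bfloat16, "vae": torch.float32, "text_encoder": torch.float32},
)
for name, module in components.items():
    print(name, next(module.parameters()).dtype)
```

Accepting both forms keeps existing single-dtype calls working, while the fallback means users only need to list the components that deviate from it.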