Description
Describe the bug
This issue is a follow-up to this PR.
That PR fixed an issue that produced the following error message:
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
This was caused by a dtype mismatch between the image input, encoded by a VAE running in float32 precision, and the UNet, configured with float16 precision.
Although this bug is now fixed, I want to point out how difficult it was to debug, because it did not occur on every run. While debugging, I found that the `dtype` property of `DiffusionPipeline` does not return a deterministic value (the same seems to apply to the `device` property). It returns the dtype of the first `torch.nn.Module` component it finds, but the component names it iterates over are not sorted, so the component it picks can change from one run to the next.
I think this may lead to other bugs in the future and may be worth changing.
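The run-to-run variation likely comes from Python itself: the logs below print the component names as a set, and `str` hashes are salted per interpreter process unless `PYTHONHASHSEED` is set, so set iteration order can differ between runs. The sketch below uses only the standard library and is an illustration of that mechanism, not diffusers code. It starts a fresh interpreter per hash seed and reports the order in which a set of the component names from the logs is iterated.

```python
import os
import subprocess
import sys

# Component names as printed in the logs (diffusers stores them in a set).
NAMES = ["unet", "vae", "text_encoder", "text_encoder_2", "tokenizer",
         "tokenizer_2", "scheduler", "image_encoder", "feature_extractor"]

# Child program: build a set from the names and print its iteration order.
CHILD = f"print(','.join(set({NAMES!r})))"


def set_order(hash_seed: str) -> list[str]:
    """Iterate over a set of NAMES in a fresh interpreter whose string
    hashing is seeded with PYTHONHASHSEED=hash_seed."""
    env = dict(os.environ, PYTHONHASHSEED=hash_seed)
    result = subprocess.run(
        [sys.executable, "-c", CHILD],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip().split(",")


if __name__ == "__main__":
    # Different hash seeds usually give different iteration orders.
    for seed in ("0", "1", "2"):
        print(f"PYTHONHASHSEED={seed}: {set_order(seed)}")
    # Sorting removes the dependence on hash order entirely.
    print("sorted:", sorted(NAMES))
```

While debugging, exporting a fixed `PYTHONHASHSEED` before running the reproduction script should make the selected component stable across runs, which can help confirm that hash ordering is the cause.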
Reproduction
To reproduce the non-deterministic behavior, first add a log line to the `dtype` property of `DiffusionPipeline` in the `src/diffusers/pipelines/pipeline_utils.py` file:
```python
@property
def dtype(self) -> torch.dtype:
    r"""
    Returns:
        `torch.dtype`: The torch dtype on which the pipeline is located.
    """
    module_names, _ = self._get_signature_keys(self)
    modules = [getattr(self, n, None) for n in module_names]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    for module in modules:
        print(f"module.dtype is {module.dtype} using {type(module).__name__} from {module_names}.")  # <--- Add this line
        return module.dtype

    return torch.float32
```
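One possible way to make this lookup deterministic is to visit components in a fixed order: an optional priority list first, then the remaining names sorted alphabetically. The following is a minimal sketch under that assumption, not the diffusers implementation. `first_component_attr`, its `priority` parameter, and `FakeModule` are hypothetical names, and `hasattr(component, attr)` stands in for the real `isinstance(m, torch.nn.Module)` check so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Any, Iterable, Mapping


def first_component_attr(
    components: Mapping[str, Any],
    attr: str,
    default: Any,
    priority: Iterable[str] = (),
) -> Any:
    """Return `attr` from the first component that has it, visiting
    components in a fixed order: names listed in `priority` first,
    then all remaining names sorted alphabetically."""
    ordered = [name for name in priority if name in components]
    ordered += sorted(name for name in components if name not in ordered)
    for name in ordered:
        component = components[name]
        # Skip missing components and components without the attribute
        # (e.g. tokenizers and schedulers have no dtype).
        if component is not None and hasattr(component, attr):
            return getattr(component, attr)
    return default


@dataclass
class FakeModule:
    """Hypothetical stand-in for a torch.nn.Module exposing `dtype`."""
    dtype: str


if __name__ == "__main__":
    components = {
        "vae": FakeModule("float32"),
        "unet": FakeModule("float16"),
        "tokenizer": object(),
        "scheduler": None,
    }
    # Alphabetical order reaches "unet" before "vae".
    print(first_component_attr(components, "dtype", "float32"))
    # An explicit priority overrides the alphabetical fallback.
    print(first_component_attr(components, "dtype", "float32", priority=("vae",)))
```

Sorting alone makes the result reproducible but still arbitrary; a priority list would let the pipeline report the dtype of a specific component such as the denoiser, which is arguably closer to what callers of `pipe.dtype` expect.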
Then create a new Python script with the following content:
```python
from diffusers import StableDiffusionXLPipeline
from diffusers.schedulers import UniPCMultistepScheduler
import torch

# initialize the models and pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16"
)
pipe.to(torch.device("cuda"))
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# generate image
print("Pipeline dtype:", pipe.dtype)
image = pipe(prompt="a dog", num_inference_steps=20).images[0]
```
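Run this script several times and compare the output (see the "Logs" section below). Every run reports `torch.float16` because all components were loaded in float16, but the component used to determine the dtype changes from run to run. If the components had different dtypes (for example, a VAE kept in float32), the value returned by `pipe.dtype` would change as well.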
Logs
1st script run gives:
module.dtype is torch.float16 using CLIPTextModelWithProjection from {'tokenizer', 'text_encoder_2', 'text_encoder', 'vae', 'unet', 'scheduler', 'image_encoder', 'tokenizer_2', 'feature_extractor'}.
Pipeline dtype: torch.float16
2nd script run gives:
module.dtype is torch.float16 using CLIPTextModel from {'tokenizer_2', 'tokenizer', 'text_encoder', 'text_encoder_2', 'vae', 'feature_extractor', 'scheduler', 'image_encoder', 'unet'}.
Pipeline dtype: torch.float16
3rd script run gives:
module.dtype is torch.float16 using UNet2DConditionModel from {'unet', 'tokenizer', 'text_encoder', 'vae', 'text_encoder_2', 'tokenizer_2', 'scheduler', 'image_encoder', 'feature_extractor'}.
Pipeline dtype: torch.float16
4th script run gives:
module.dtype is torch.float16 using AutoencoderKL from {'vae', 'scheduler', 'text_encoder_2', 'unet', 'image_encoder', 'tokenizer_2', 'tokenizer', 'text_encoder', 'feature_extractor'}.
Pipeline dtype: torch.float16
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.5
- Transformers version: 4.47.1
- Accelerate version: 1.2.1
- PEFT version: 0.13.2
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post3
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No