Closed
Describe the bug
LattePipeline fails with a dtype mismatch in the transformer block during the attention phase when enable_temporal_attentions
is enabled (which is the default). Disabling temporal attention skips the affected block of code, so the model works.
Reproduction
Load diffusers.LattePipeline with a specific torch_dtype and run it. It fails for both torch.float16 and torch.bfloat16.
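A minimal sketch of the reproduction steps follows. The report does not name a checkpoint, prompt, or device, so the `model_id`, the prompt text, and the `device` default are assumptions for illustration only, and `reproduce` is a hypothetical helper. The traceback shows accelerate hooks, so the original setup may have used offloading, which this sketch does not replicate.

```python
import torch


def reproduce(dtype=torch.float16, model_id="maxin-cn/Latte-1", device="cuda"):
    """Hypothetical helper: load LattePipeline in a reduced precision and run it.

    model_id, the prompt, and device are assumptions; the report only states
    that the pipeline was loaded with a specific torch_dtype.
    """
    # Imported lazily so that defining this helper does not require diffusers.
    from diffusers import LattePipeline

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)
    # enable_temporal_attentions=True is the default; it is passed explicitly
    # here because the report ties the failure to this code path.
    return pipe(prompt="a small cactus in the desert", enable_temporal_attentions=True)


# Usage (needs the model weights and a suitable device):
#   reproduce(torch.float16)
#   reproduce(torch.bfloat16)
```

Per the report, passing `enable_temporal_attentions=False` to the pipeline call avoids the failing block.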
Logs
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/pipelines/latte/pipeline_latte.py:819 in __call__ │
│ │
│ 818 │ │ │ │ # predict noise model_output │
│ ❱ 819 │ │ │ │ noise_pred = self.transformer( │
│ 820 │ │ │ │ │ hidden_states=latent_model_input, │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl │
│ │
│ 1738 │ │ else: │
│ ❱ 1739 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1740 │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750 in _call_impl │
│ │
│ 1749 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1750 │ │ │ return forward_call(*args, **kwargs) │
│ 1751 │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/accelerate/hooks.py:176 in new_forward │
│ │
│ 175 │ │ else: │
│ ❱ 176 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 177 │ │ return module._hf_hook.post_forward(module, output) │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/transformers/latte_transformer_3d.py:291 in forward │
│ │
│ 290 │ │ │ │ else: │
│ ❱ 291 │ │ │ │ │ hidden_states = temp_block( │
│ 292 │ │ │ │ │ │ hidden_states,
...
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward │
│ │
│ 124 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 125 │ │ return F.linear(input, self.weight, self.bias) │
│ 126 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
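The final frame is a plain `F.linear` call, so the same class of error can be shown in isolation. The sketch below is not the diffusers code and not a proposed fix. It only illustrates, under the assumption that a float32 tensor reaches a reduced-precision `nn.Linear`, how the mismatch arises and how casting the input to the weight dtype avoids it. `try_forward` is a hypothetical helper, and bfloat16 is used so that the example also runs on CPU.

```python
import torch
import torch.nn as nn

# A linear layer whose weights are in reduced precision, as when a model
# is loaded with torch_dtype=torch.bfloat16.
layer = nn.Linear(4, 4).to(torch.bfloat16)

# An input left in float32, standing in for a tensor that was not cast
# before it reached the layer.
x = torch.randn(2, 4)


def try_forward(module, inputs):
    """Return the error message if the forward pass raises, else None."""
    try:
        module(inputs)
        return None
    except RuntimeError as err:
        return str(err)


# Mixed dtypes: F.linear rejects a float32 input with bfloat16 weights.
mismatch_error = try_forward(layer, x)
print("mismatch:", mismatch_error)

# Casting the input to the weight dtype makes the same call succeed.
fixed = layer(x.to(layer.weight.dtype))
print("fixed output dtype:", fixed.dtype)
```

Which tensor in the temporal attention path stays in float32 is not identified in this report, so where such a cast would belong in `latte_transformer_3d.py` remains open.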
System Info
diffusers==main