
LattePipeline fails with dtype mismatch #11137

Closed
@vladmandic

Description


Describe the bug

LattePipeline fails with a dtype mismatch in the transformer block during the attention phase when `enable_temporal_attentions` is enabled (the default). Disabling temporal attention skips the affected code path, and the model then runs.
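Besides disabling temporal attention, one possible stopgap is to make every `nn.Linear` cast its incoming activation to its own weight dtype. This is a hypothetical user-side workaround, not a diffusers API: `cast_linear_inputs` is an illustrative name, the toy `nn.Sequential` stands in for the real transformer, and it has not been tested against LattePipeline. bfloat16 is used so the sketch runs on CPU.

```python
import torch
import torch.nn as nn


def cast_linear_inputs(model: nn.Module) -> list:
    """Hypothetical helper: register a forward pre-hook on every nn.Linear in
    `model` that casts the first positional input to the layer's weight dtype.
    Returns the hook handles so the hooks can be removed later."""

    def _pre_hook(module: nn.Linear, args):
        x, *rest = args
        if x.dtype != module.weight.dtype:
            x = x.to(module.weight.dtype)
        # Returning a tuple replaces the positional inputs passed to forward().
        return (x, *rest)

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            handles.append(m.register_forward_pre_hook(_pre_hook))
    return handles


# Toy stand-in for a reduced-precision block (NOT the real Latte transformer).
block = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 4)).to(torch.bfloat16)
handles = cast_linear_inputs(block)

out = block(torch.randn(2, 8))  # float32 input, bfloat16 weights
print(out.dtype)  # torch.bfloat16

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

This only hides the symptom: whatever produces the float32 activation upstream is left untouched, so a proper fix still belongs in the model code.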

Reproduction

Load `diffusers.LattePipeline` with an explicit `torch_dtype` and run inference; it fails with both `torch.float16` and `torch.bfloat16`.
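When a model is loaded with a given `torch_dtype`, a quick diagnostic is to list any floating-point parameters or buffers that did not end up in that dtype. The helper below is a hypothetical sketch (`find_dtype_mismatches` and the toy module are illustrative names, not part of diffusers); the toy shows that a buffer registered after `.to(dtype)` stays float32.

```python
import torch
import torch.nn as nn


def find_dtype_mismatches(model: nn.Module, expected: torch.dtype) -> list:
    """Return names of floating-point parameters and buffers whose dtype
    differs from `expected`."""
    mismatches = []
    for name, p in model.named_parameters():
        if p.is_floating_point() and p.dtype != expected:
            mismatches.append(name)
    # named_buffers() includes non-persistent buffers as well.
    for name, b in model.named_buffers():
        if b.is_floating_point() and b.dtype != expected:
            mismatches.append(name)
    return mismatches


class ToyBlock(nn.Module):
    """Hypothetical toy module; not the real Latte transformer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


model = ToyBlock().to(torch.float16)
# A buffer registered after the dtype conversion is never cast, so it stays float32.
model.register_buffer("pos_embed_fp32", torch.zeros(1, 8), persistent=False)
print(find_dtype_mismatches(model, torch.float16))  # ['pos_embed_fp32']
```

On the real pipeline one might call `find_dtype_mismatches(pipe.transformer, torch.float16)` after loading. It cannot see float32 tensors that are created on the fly inside `forward`, so an empty result does not rule those out.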

Logs

│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/pipelines/latte/pipeline_latte.py:819 in __call__
│
│   818 │   │   │   │   # predict noise model_output
│ ❱ 819 │   │   │   │   noise_pred = self.transformer(
│   820 │   │   │   │   │   hidden_states=latent_model_input,
│
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl
│
│   1738 │   │   else:
│ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)
│   1740
│
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750 in _call_impl
│
│   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):
│ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)
│   1751
│
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/accelerate/hooks.py:176 in new_forward
│
│   175 │   │   else:
│ ❱ 176 │   │   │   output = module._old_forward(*args, **kwargs)
│   177 │   │   return module._hf_hook.post_forward(module, output)
│
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/transformers/latte_transformer_3d.py:291 in forward
│
│   290 │   │   │   │   else:
│ ❱ 291 │   │   │   │   │   hidden_states = temp_block(
│   292 │   │   │   │   │   │   hidden_states,
...
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward
│
│   124 │   def forward(self, input: Tensor) -> Tensor:
│ ❱ 125 │   │   return F.linear(input, self.weight, self.bias)
│   126
╰──────────────────────────────────────────────────╯
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
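The error means the input to `F.linear` is float32 while the layer's weights are half precision. One plausible way this happens (not confirmed as the root cause in Latte) is PyTorch type promotion: adding a float32 tensor to a reduced-precision activation silently yields float32, which the next linear layer then rejects. The sketch below reproduces that pattern with toy tensors, using bfloat16 instead of float16 so it runs on CPU; `pos` is a hypothetical stand-in for any float32 tensor mixed into the hidden states.

```python
import torch
import torch.nn.functional as F

# Toy reduced-precision weights and activations (bfloat16 so this runs on CPU).
weight = torch.randn(4, 8, dtype=torch.bfloat16)
bias = torch.randn(4, dtype=torch.bfloat16)
hidden = torch.randn(2, 8, dtype=torch.bfloat16)

# Hypothetical float32 tensor (float32 is the default dtype).
pos = torch.zeros(2, 8)

# bfloat16 + float32 (both non-scalar) promotes the result to float32.
promoted = hidden + pos
print(promoted.dtype)  # torch.float32

raised = False
try:
    F.linear(promoted, weight, bias)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Casting the added tensor to the activation dtype keeps the matmul in bfloat16.
fixed = hidden + pos.to(hidden.dtype)
out = F.linear(fixed, weight, bias)
print(out.dtype)  # torch.bfloat16
```

The exact wording of the RuntimeError varies between PyTorch versions, but the mixed-dtype matmul is rejected either way.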

System Info

diffusers==main

Who can help?

@a-r-r-o-w @DN6

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working)