diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 4b359021f29d..132c258455ea 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -273,7 +273,7 @@ def forward( hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) if i == 0 and num_frame > 1: - hidden_states = hidden_states + self.temp_pos_embed + hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype) if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func(