diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 6dce88826ba0..006ea8b4013f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1154,6 +1154,9 @@ def get_1d_rotary_pos_embed( / linear_factor ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]