From 8e750052c3f8894efbe209f7c054740ad88d77e3 Mon Sep 17 00:00:00 2001 From: Kai zheng Date: Tue, 18 Feb 2025 14:59:06 +0800 Subject: [PATCH 1/2] get_1d_rotary_pos_embed support npu --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 390b752abe15..5e6d0e23c4ae 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1151,7 +1151,7 @@ def get_1d_rotary_pos_embed( / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) / linear_factor ) # [D/2] - freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + freqs = torch.outer(pos, freqs).float() # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] From 5cdadaaf6fe2fd8e7f63032e59e4a871d9a31f43 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 27 Feb 2025 06:41:18 +0000 Subject: [PATCH 2/2] Update src/diffusers/models/embeddings.py --- src/diffusers/models/embeddings.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5e6d0e23c4ae..a40a39170da3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1151,7 +1151,10 @@ def get_1d_rotary_pos_embed( / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) / linear_factor ) # [D/2] - freqs = torch.outer(pos, freqs).float() # type: ignore # [S, D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]