Commit cecada5

Authored by leisuzz and J石页

NPU adaption for RMSNorm (#10534)

* NPU adaption for RMSNorm

* NPU adaption for RMSNorm

---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>

1 parent 17d99c4 · commit cecada5
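
For context, the fused kernel this commit adopts replaces an eager RMSNorm, y = x * rsqrt(mean(x^2, dim=-1) + eps) * weight. The sketch below mirrors the eager branch that the diff keeps as the fallback path; the function name and tensor shapes are illustrative, not part of the commit.

# Minimal reference RMSNorm matching the eager (non-NPU) branch in the diff.
# rms_norm_reference is a hypothetical name; shapes here are illustrative.
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Statistics are computed in float32 for numerical stability, as in the diff.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps)
    # Cast back to the weight's half precision before the affine scale.
    if weight.dtype in (torch.float16, torch.bfloat16):
        y = y.to(weight.dtype)
    return y * weight

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
print(rms_norm_reference(x, w, eps=1e-6).dtype)  # torch.float16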

File tree

1 file changed: +22 / -11 lines


src/diffusers/models/normalization.py

Lines changed: 22 additions & 11 deletions
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings

@@ -505,19 +505,30 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
                 self.bias = nn.Parameter(torch.zeros(dim))

     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)

         return hidden_states
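
From the caller's side nothing changes; the dispatch is internal to forward. A minimal usage sketch follows, assuming the class modified here is RMSNorm in src/diffusers/models/normalization.py (consistent with the file path and hunk header above): on a machine without torch_npu, is_torch_npu_available() is False and the eager branch runs, while on Ascend hardware the same call goes through torch_npu.npu_rms_norm.

# Minimal sketch exercising the patched forward; the import path assumes the
# class touched by this commit is RMSNorm (an assumption, per the hunk header).
import torch
from diffusers.models.normalization import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6, elementwise_affine=True)
x = torch.randn(2, 16, 64)
out = norm(x)  # fused torch_npu.npu_rms_norm on NPU, eager math elsewhere
print(out.shape)  # torch.Size([2, 16, 64])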
