20 | 20 | import torch.nn as nn
21 | 21 | import torch.nn.functional as F
22 | 22 |
23 |    | -from ..utils import is_torch_version
   | 23 | +from ..utils import is_torch_npu_available, is_torch_version
24 | 24 | from .activations import get_activation
25 | 25 | from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
26 | 26 |
@@ -505,19 +505,30 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
505 | 505 |         self.bias = nn.Parameter(torch.zeros(dim))
506 | 506 |
507 | 507 |     def forward(self, hidden_states):
508 |     | -        input_dtype = hidden_states.dtype
509 |     | -        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
510 |     | -        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
511 |     | -
512 |     | -        if self.weight is not None:
513 |     | -            # convert into half-precision if necessary
514 |     | -            if self.weight.dtype in [torch.float16, torch.bfloat16]:
515 |     | -                hidden_states = hidden_states.to(self.weight.dtype)
516 |     | -            hidden_states = hidden_states * self.weight
    | 508 | +        if is_torch_npu_available():
    | 509 | +            import torch_npu
    | 510 | +
    | 511 | +            if self.weight is not None:
    | 512 | +                # convert into half-precision if necessary
    | 513 | +                if self.weight.dtype in [torch.float16, torch.bfloat16]:
    | 514 | +                    hidden_states = hidden_states.to(self.weight.dtype)
    | 515 | +            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
517 | 516 |             if self.bias is not None:
518 | 517 |                 hidden_states = hidden_states + self.bias
519 | 518 |         else:
520 |     | -            hidden_states = hidden_states.to(input_dtype)
    | 519 | +            input_dtype = hidden_states.dtype
    | 520 | +            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    | 521 | +            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
    | 522 | +
    | 523 | +            if self.weight is not None:
    | 524 | +                # convert into half-precision if necessary
    | 525 | +                if self.weight.dtype in [torch.float16, torch.bfloat16]:
    | 526 | +                    hidden_states = hidden_states.to(self.weight.dtype)
    | 527 | +                hidden_states = hidden_states * self.weight
    | 528 | +                if self.bias is not None:
    | 529 | +                    hidden_states = hidden_states + self.bias
    | 530 | +            else:
    | 531 | +                hidden_states = hidden_states.to(input_dtype)
521 | 532 |
522 | 533 |         return hidden_states
523 | 534 |
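For reference, the eager branch kept in the `else` path above can be exercised on its own. The sketch below is a standalone restatement of that fallback RMSNorm; the tensor shapes and the `eps=1e-6` value are illustrative assumptions, not values taken from this diff. On Ascend devices, `is_torch_npu_available()` is true and the same normalization is instead delegated to the fused `torch_npu.npu_rms_norm(hidden_states, weight, epsilon=eps)` kernel, whose first return value is the normalized tensor (hence the `[0]` indexing in the added line 515).

```python
import torch


def rms_norm_reference(hidden_states, weight=None, bias=None, eps=1e-6):
    """Eager RMSNorm mirroring the non-NPU branch in the diff above."""
    input_dtype = hidden_states.dtype
    # Accumulate the mean of squares in float32 for numerical stability.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)

    if weight is not None:
        # Cast back to half precision before applying the affine scale.
        if weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
        if bias is not None:
            hidden_states = hidden_states + bias
    else:
        hidden_states = hidden_states.to(input_dtype)
    return hidden_states


# Illustrative usage (shapes are assumptions, not from the PR).
x = torch.randn(2, 16, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
out = rms_norm_reference(x, w)
print(out.dtype, out.shape)  # torch.float16 torch.Size([2, 16, 64])
```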