Commit 07860f9

Authored by leisuzz, J石页, and sayakpaul
NPU Adaption for Sanna (#10409)
* NPU Adaption for Sanna

---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 87252d8 commit 07860f9

2 files changed: +18 -2 lines changed


examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 13 additions & 2 deletions
@@ -63,6 +63,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -74,6 +75,9 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+
 
 def save_model_card(
     repo_id: str,
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -924,8 +929,7 @@ def main(args):
                 image.save(image_filename)
 
         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -988,6 +992,13 @@
         # because Gemma2 is particularly suited for bfloat16.
         text_encoder.to(dtype=torch.bfloat16)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = SanaPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
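Aside from replacing the CUDA-only cache clearing with free_memory(), the hunks above follow a single pattern: configure the NPU memory format once at import time, expose an opt-in CLI flag, and switch the transformer to its NPU flash-attention path only when the torch_npu extension is actually present, raising otherwise. The condensed sketch below restates that pattern outside the full script; the helper name maybe_enable_npu_flash_attention is purely illustrative, and transformer stands in for the Sana transformer the script builds elsewhere.

import torch

from diffusers.utils.import_utils import is_torch_npu_available

# Module-level setup mirrored from the diff: with torch_npu present, disable
# the NPU's internal (private) memory formats so tensors keep standard layouts.
if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False


def maybe_enable_npu_flash_attention(transformer, enable: bool) -> None:
    # Opt-in guard mirrored from the diff: the flag is honoured only when
    # torch_npu is importable; otherwise it fails loudly instead of silently
    # falling back to another attention backend.
    if enable:
        if is_torch_npu_available():
            transformer.enable_npu_flash_attention()
        else:
            raise ValueError("npu flash attention requires torch_npu extensions and an NPU device")

In the script itself the flag arrives through argparse as --enable_npu_flash_attention, and the guard runs just before the text-encoding pipeline is created, as the last hunk shows.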

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 0 deletions
@@ -3154,6 +3154,11 @@ def __call__(
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
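The five added lines adjust the mask for the fused-attention call later in this __call__: the mask, already reshaped by the existing line to (batch, heads, -1, target_length), is repeated along the query dimension and then normalised to a boolean mask whose True entries mark positions to exclude. A boolean SDPA-style mask (True = attend) is therefore inverted, while a non-boolean additive mask (zero where attending, large negative where masked) becomes True at exactly the masked entries through a plain bool cast. Below is a self-contained sketch of that conversion, written on the assumption, suggested by the inversion, that the downstream NPU fused-attention kernel treats True as masked out; the helper name is illustrative only.

import torch


def to_npu_style_mask(attention_mask: torch.Tensor, batch_size: int, heads: int, query_len: int) -> torch.Tensor:
    # Illustrative helper (not library API) following the added lines above:
    # reshape to (batch, heads, -1, key_len), broadcast across the query axis,
    # then convert to bool with True marking masked-out positions.
    attention_mask = attention_mask.view(batch_size, heads, -1, attention_mask.shape[-1])
    attention_mask = attention_mask.repeat(1, 1, query_len, 1)
    if attention_mask.dtype == torch.bool:
        # SDPA-style boolean masks use True for "keep this position", so invert.
        attention_mask = torch.logical_not(attention_mask)
    else:
        # Additive masks are 0 where attended and non-zero (e.g. a large
        # negative bias) where masked, so bool() already yields True there.
        attention_mask = attention_mask.bool()
    return attention_mask

In the processor the repeat factor is hidden_states.shape[1], i.e. the query sequence length, which is what query_len stands for here.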
