63 | 63 |     is_wandb_available,
64 | 64 | )
65 | 65 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
   | 66 | +from diffusers.utils.import_utils import is_torch_npu_available
66 | 67 | from diffusers.utils.torch_utils import is_compiled_module
67 | 68 |
68 | 69 |

74 | 75 |
75 | 76 | logger = get_logger(__name__)
76 | 77 |
   | 78 | +if is_torch_npu_available():
   | 79 | +    torch.npu.config.allow_internal_format = False
   | 80 | +
77 | 81 |
78 | 82 | def save_model_card(
79 | 83 |     repo_id: str,
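This hunk gates all Ascend-specific setup behind `is_torch_npu_available`, and `allow_internal_format = False` keeps tensors in the standard ND layout rather than the NPU's private internal formats, which tends to be the safer default for model code not written against them. For readers unfamiliar with the helper, a rough sketch of how such an availability check is typically implemented (the real one in `diffusers.utils.import_utils` may differ in detail):

```python
# Hypothetical sketch of an NPU availability check; the actual helper in
# diffusers.utils.import_utils may differ in detail.
import importlib.util


def is_torch_npu_available() -> bool:
    # torch_npu is Huawei Ascend's PyTorch plugin; importing it registers
    # the "npu" device and the torch.npu namespace.
    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
        return False
    import torch
    import torch_npu  # noqa: F401

    return hasattr(torch, "npu") and torch.npu.is_available()
```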
|
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
|
601 | 605 |     )
602 | 606 |     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
603 | 607 |     parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable VAE tiling in log validation")
    | 608 | +    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable flash attention for NPU")
604 | 609 |
605 | 610 |     if input_args is not None:
606 | 611 |         args = parser.parse_args(input_args)
|
@@ -924,8 +929,7 @@ def main(args):
|
924 | 929 |                     image.save(image_filename)
925 | 930 |
926 | 931 |             del pipeline
927 |     | -            if torch.cuda.is_available():
928 |     | -                torch.cuda.empty_cache()
    | 932 | +            free_memory()
929 | 933 |
930 | 934 |     # Handle the repository creation
931 | 935 |     if accelerator.is_main_process:
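Swapping the CUDA-only `empty_cache` for `free_memory()` makes this cleanup backend-agnostic (CUDA, NPU, MPS, ...); diffusers provides such a helper in `diffusers.training_utils`. A minimal sketch of the idea, assuming the usual backends (the library version may cover more):

```python
# Minimal sketch of a device-agnostic cache flush in the spirit of
# diffusers.training_utils.free_memory.
import gc

import torch


def free_memory() -> None:
    # Drop Python-level references first so the allocator can actually
    # release the cached blocks below.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
```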
|
@@ -988,6 +992,13 @@ def main(args):
|
988 | 992 |     # because Gemma2 is particularly suited for bfloat16.
989 | 993 |     text_encoder.to(dtype=torch.bfloat16)
990 | 994 |
    | 995 | +    if args.enable_npu_flash_attention:
    | 996 | +        if is_torch_npu_available():
    | 997 | +            logger.info("NPU flash attention enabled.")
    | 998 | +            transformer.enable_npu_flash_attention()
    | 999 | +        else:
    | 1000 | +            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
    | 1001 | +
991 | 1002 |     # Initialize a text encoding pipeline and keep it to CPU for now.
992 | 1003 |     text_encoding_pipeline = SanaPipeline.from_pretrained(
993 | 1004 |         args.pretrained_model_name_or_path,
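`enable_npu_flash_attention()` comes from the model's diffusers base class and swaps the attention processors for the torch_npu fused implementation, so failing fast in the `else` branch avoids silently training without the fused kernels. A standalone sketch of the same guard (the checkpoint id below is only an example, not necessarily the one used with this script):

```python
# Standalone sketch of the guard this hunk adds; the checkpoint id is an
# example placeholder.
import torch
from diffusers import SanaTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
if is_torch_npu_available():
    # Replaces the attention processors with the torch_npu fused variant.
    transformer.enable_npu_flash_attention()
```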