From 2816222f958280b979ca8533d3bc3b9ea685b944 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 9 Jan 2025 16:58:31 -0800 Subject: [PATCH 1/6] implementing flux on TPUs with ptxla --- .../pytorch_xla/inference/flux/README.md | 100 +++++++++++++++++ .../inference/flux/flux_inference.py | 102 ++++++++++++++++++ .../{ => training/text_to_image}/README.md | 0 .../text_to_image}/requirements.txt | 0 .../text_to_image}/train_text_to_image_xla.py | 0 src/diffusers/models/attention_processor.py | 16 ++- 6 files changed, 214 insertions(+), 4 deletions(-) create mode 100644 examples/research_projects/pytorch_xla/inference/flux/README.md create mode 100644 examples/research_projects/pytorch_xla/inference/flux/flux_inference.py rename examples/research_projects/pytorch_xla/{ => training/text_to_image}/README.md (100%) rename examples/research_projects/pytorch_xla/{ => training/text_to_image}/requirements.txt (100%) rename examples/research_projects/pytorch_xla/{ => training/text_to_image}/train_text_to_image_xla.py (100%) diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md new file mode 100644 index 000000000000..dd7e23c57049 --- /dev/null +++ b/examples/research_projects/pytorch_xla/inference/flux/README.md @@ -0,0 +1,100 @@ +# Generating images using Flux and PyTorch/XLA + +The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation. + +It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested. + +## Create TPU + +To create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e) + +## Setup TPU environment + +SSH into the VM and install Pytorch, Pytorch/XLA + +```bash +pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html +pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +``` + +Verify that PyTorch and PyTorch/XLA were installed correctly: + +```bash +python3 -c "import torch; import torch_xla;" +``` + +Install dependencies + +```bash +pip install transformers accelerate sentencepiece structlog +pushd ../../.. +pip install . +popd +``` + +## Run the inference job + +### Authenticate + +Run the following command to authenticate your token in order to download Flux weights. + +```bash +huggingface-cli login +``` + +Then run: + +```bash +python flux_inference.py +``` + +The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest. + +On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel): + +```bash +WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. +Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s] +Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers +Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s] +2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev +2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev +2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev +2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev +Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s] +Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s] +Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s] +Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s] +2025-01-10 00:51:34 [info ] starting compilation run... +2025-01-10 00:51:35 [info ] starting compilation run... +2025-01-10 00:51:37 [info ] starting compilation run... +2025-01-10 00:51:37 [info ] starting compilation run... +2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec. +2025-01-10 00:52:53 [info ] starting inference run... +2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec. +2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec. +2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec. +2025-01-10 00:52:57 [info ] starting inference run... +2025-01-10 00:52:57 [info ] starting inference run... +2025-01-10 00:52:58 [info ] starting inference run... +2025-01-10 00:53:22 [info ] inference time: 25.112665320000815 +2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655 +2025-01-10 00:53:38 [info ] inference time: 7.693858365000779 +2025-01-10 00:53:46 [info ] inference time: 7.690621814001133 +2025-01-10 00:53:53 [info ] inference time: 7.679490454000188 +2025-01-10 00:54:01 [info ] inference time: 7.68949568500102 +2025-01-10 00:54:09 [info ] inference time: 7.686633744000574 +2025-01-10 00:54:16 [info ] inference time: 7.696786873999372 +2025-01-10 00:54:24 [info ] inference time: 7.691988694999964 +2025-01-10 00:54:32 [info ] inference time: 7.700649563999832 +2025-01-10 00:54:39 [info ] inference time: 7.684993574001055 +2025-01-10 00:54:47 [info ] inference time: 7.68343457499941 +2025-01-10 00:54:55 [info ] inference time: 7.667921153999487 +2025-01-10 00:55:02 [info ] inference time: 7.683585194001353 +2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec. +2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec. +2025-01-10 00:55:10 [info ] inference time: 7.673799695001435 +2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec. +2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt +2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec. +``` \ No newline at end of file diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py new file mode 100644 index 000000000000..b3a14e222139 --- /dev/null +++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py @@ -0,0 +1,102 @@ +from time import perf_counter +from pathlib import Path +from argparse import ArgumentParser + +import structlog + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.debug.profiler as xp +import torch_xla.debug.metrics as met +from diffusers import FluxPipeline +import torch_xla.distributed.xla_multiprocessing as xmp + +logger = structlog.get_logger() +metrics_filepath = '/tmp/metrics_report.txt' + +def _main(index, args, text_pipe, ckpt_id): + + cache_path = Path('/tmp/data/compiler_cache_tRiLlium_eXp') + cache_path.mkdir(parents=True, exist_ok=True) + xr.initialize_cache(str(cache_path), readonly=False) + + profile_path = Path('/tmp/data/profiler_out_tRiLlium_eXp') + profile_path.mkdir(parents=True, exist_ok=True) + profiler_port = 9012 + profile_duration = args.profile_duration + if args.profile: + logger.info(f'starting profiler on port {profiler_port}') + _ = xp.start_server(profiler_port) + device0 = xm.xla_device() + + logger.info(f'loading flux from {ckpt_id}') + flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None, + text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0) + + prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side' + width = args.width + height = args.height + guidance = args.guidance + n_steps = 4 if args.schnell else 28 + + logger.info('starting compilation run...') + ts = perf_counter() + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512) + prompt_embeds = prompt_embeds.to(device0) + pooled_prompt_embeds = pooled_prompt_embeds.to(device0) + + image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0] + logger.info(f'compilation took {perf_counter() - ts} sec.') + image.save('/tmp/compile_out.png') + + base_seed = 4096 if args.seed is None else args.seed + seed_range = 1000 + unique_seed = base_seed + index * seed_range + xm.set_rng_state(seed=unique_seed, device=device0) + times = [] + logger.info('starting inference run...') + for _ in range(args.itters): + ts = perf_counter() + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512) + prompt_embeds = prompt_embeds.to(device0) + pooled_prompt_embeds = pooled_prompt_embeds.to(device0) + + if args.profile: + xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration) + image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0] + inference_time = perf_counter() - ts + if index == 0: + logger.info(f"inference time: {inference_time}") + times.append(inference_time) + logger.info(f'avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.') + image.save(f'/tmp/inference_out-{index}.png') + if index == 0: + metrics_report = met.metrics_report() + with open(metrics_filepath, 'w+') as fout: + fout.write(metrics_report) + logger.info(f'saved metric information as {metrics_filepath}') + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev') + parser.add_argument('--width', type=int, default=1024, help='width of the image to generate') + parser.add_argument('--height', type=int, default=1024, help='height of the image to generate') + parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev') + parser.add_argument('--seed', type=int, default=None, help='seed for inference') + parser.add_argument('--profile', action='store_true', help='enable profiling') + parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.') + parser.add_argument('--itters', type=int, default=15, help='tiems to run inference and get avg time in sec.') + args = parser.parse_args() + if args.schnell: + ckpt_id = "black-forest-labs/FLUX.1-schnell" + else: + ckpt_id = "black-forest-labs/FLUX.1-dev" + text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to('cpu') + xmp.spawn(_main, args=(args, text_pipe, ckpt_id)) \ No newline at end of file diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/training/text_to_image/README.md similarity index 100% rename from examples/research_projects/pytorch_xla/README.md rename to examples/research_projects/pytorch_xla/training/text_to_image/README.md diff --git a/examples/research_projects/pytorch_xla/requirements.txt b/examples/research_projects/pytorch_xla/training/text_to_image/requirements.txt similarity index 100% rename from examples/research_projects/pytorch_xla/requirements.txt rename to examples/research_projects/pytorch_xla/training/text_to_image/requirements.txt diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py similarity index 100% rename from examples/research_projects/pytorch_xla/train_text_to_image_xla.py rename to examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4d7ae6bef26e..5cbe234b0714 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2318,9 +2318,12 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + if XLA_AVAILABLE: + query /= math.sqrt(head_dim) + hidden_states = flash_attention(query, key, value, causal=False) + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2521,7 +2524,12 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + if XLA_AVAILABLE: + query /= math.sqrt(head_dim) + hidden_states = flash_attention(query, key, value) + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 378f7ed4fdf76b13411d26d78d2dc7e914113c55 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 13 Jan 2025 18:07:24 +0000 Subject: [PATCH 2/6] add xla flux attention class --- .../inference/flux/flux_inference.py | 1 + src/diffusers/models/attention_processor.py | 121 ++++++++++++++++-- src/diffusers/models/modeling_utils.py | 8 +- 3 files changed, 113 insertions(+), 17 deletions(-) diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py index b3a14e222139..6799495c9068 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py +++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py @@ -33,6 +33,7 @@ def _main(index, args, text_pipe, ckpt_id): logger.info(f'loading flux from {ckpt_id}') flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0) + flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side' width = args.width diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5cbe234b0714..942eefc390be 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -297,7 +297,7 @@ def __init__( self.set_processor(processor) def set_use_xla_flash_attention( - self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None + self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs ) -> None: r""" Set whether to use xla flash attention from `torch_xla` or not. @@ -316,7 +316,10 @@ def set_use_xla_flash_attention( elif is_spmd() and is_torch_xla_version("<", "2.4"): raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" else: - processor = XLAFlashAttnProcessor2_0(partition_spec) + if len(kwargs) > 0 and kwargs.get("is_flux", None): + processor = XLAFluxFlashAttnProcessor2_0(partition_spec) + else: + processor = XLAFlashAttnProcessor2_0(partition_spec) else: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() @@ -2318,11 +2321,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if XLA_AVAILABLE: - query /= math.sqrt(head_dim) - hidden_states = flash_attention(query, key, value, causal=False) - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2523,12 +2522,8 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - - if XLA_AVAILABLE: - query /= math.sqrt(head_dim) - hidden_states = flash_attention(query, key, value) - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3430,6 +3425,106 @@ def __call__( return hidden_states +class XLAFluxFlashAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. + """ + + def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + self.partition_spec = partition_spec + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query /= math.sqrt(head_dim) + hidden_states = flash_attention(query, key, value, causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class MochiVaeAttnProcessor2_0: r""" Attention processor used in Mochi VAE. diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 17e9d2043150..f7c18c440619 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -227,14 +227,14 @@ def disable_npu_flash_attention(self) -> None: self.set_use_npu_flash_attention(False) def set_use_xla_flash_attention( - self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None + self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None, **kwargs ) -> None: # Recursively walk through all the children. # Any children which exposes the set_use_xla_flash_attention method # gets the message def fn_recursive_set_flash_attention(module: torch.nn.Module): if hasattr(module, "set_use_xla_flash_attention"): - module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec) + module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec, **kwargs) for child in module.children(): fn_recursive_set_flash_attention(child) @@ -243,11 +243,11 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module): if isinstance(module, torch.nn.Module): fn_recursive_set_flash_attention(module) - def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None): + def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None, **kwargs): r""" Enable the flash attention pallals kernel for torch_xla. """ - self.set_use_xla_flash_attention(True, partition_spec) + self.set_use_xla_flash_attention(True, partition_spec, **kwargs) def disable_xla_flash_attention(self): r""" From ad1e1cf309d0bb1819560cd8132a3919dfdd5bed Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 13 Jan 2025 18:10:43 +0000 Subject: [PATCH 3/6] run make style/quality --- .../inference/flux/flux_inference.py | 97 +++++++++++-------- src/diffusers/models/attention_processor.py | 10 +- 2 files changed, 62 insertions(+), 45 deletions(-) diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py index 6799495c9068..1ab80a7ec664 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py +++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py @@ -1,103 +1,120 @@ -from time import perf_counter -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path +from time import perf_counter import structlog - import torch import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr -import torch_xla.debug.profiler as xp import torch_xla.debug.metrics as met -from diffusers import FluxPipeline +import torch_xla.debug.profiler as xp import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.runtime as xr + +from diffusers import FluxPipeline + logger = structlog.get_logger() -metrics_filepath = '/tmp/metrics_report.txt' +metrics_filepath = "/tmp/metrics_report.txt" -def _main(index, args, text_pipe, ckpt_id): - cache_path = Path('/tmp/data/compiler_cache_tRiLlium_eXp') +def _main(index, args, text_pipe, ckpt_id): + cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp") cache_path.mkdir(parents=True, exist_ok=True) xr.initialize_cache(str(cache_path), readonly=False) - profile_path = Path('/tmp/data/profiler_out_tRiLlium_eXp') + profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp") profile_path.mkdir(parents=True, exist_ok=True) profiler_port = 9012 profile_duration = args.profile_duration if args.profile: - logger.info(f'starting profiler on port {profiler_port}') + logger.info(f"starting profiler on port {profiler_port}") _ = xp.start_server(profiler_port) device0 = xm.xla_device() - logger.info(f'loading flux from {ckpt_id}') - flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None, - text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0) + logger.info(f"loading flux from {ckpt_id}") + flux_pipe = FluxPipeline.from_pretrained( + ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16 + ).to(device0) flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) - prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side' + prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side" width = args.width height = args.height guidance = args.guidance n_steps = 4 if args.schnell else 28 - logger.info('starting compilation run...') + logger.info("starting compilation run...") ts = perf_counter() with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( - prompt=prompt, prompt_2=None, max_sequence_length=512) + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) prompt_embeds = prompt_embeds.to(device0) pooled_prompt_embeds = pooled_prompt_embeds.to(device0) - image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, - num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0] - logger.info(f'compilation took {perf_counter() - ts} sec.') - image.save('/tmp/compile_out.png') + image = flux_pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=28, + guidance_scale=guidance, + height=height, + width=width, + ).images[0] + logger.info(f"compilation took {perf_counter() - ts} sec.") + image.save("/tmp/compile_out.png") base_seed = 4096 if args.seed is None else args.seed seed_range = 1000 unique_seed = base_seed + index * seed_range xm.set_rng_state(seed=unique_seed, device=device0) times = [] - logger.info('starting inference run...') + logger.info("starting inference run...") for _ in range(args.itters): ts = perf_counter() with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( - prompt=prompt, prompt_2=None, max_sequence_length=512) + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) prompt_embeds = prompt_embeds.to(device0) pooled_prompt_embeds = pooled_prompt_embeds.to(device0) if args.profile: xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration) - image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, - num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0] + image = flux_pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=n_steps, + guidance_scale=guidance, + height=height, + width=width, + ).images[0] inference_time = perf_counter() - ts if index == 0: logger.info(f"inference time: {inference_time}") times.append(inference_time) - logger.info(f'avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.') - image.save(f'/tmp/inference_out-{index}.png') + logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.") + image.save(f"/tmp/inference_out-{index}.png") if index == 0: metrics_report = met.metrics_report() - with open(metrics_filepath, 'w+') as fout: + with open(metrics_filepath, "w+") as fout: fout.write(metrics_report) - logger.info(f'saved metric information as {metrics_filepath}') + logger.info(f"saved metric information as {metrics_filepath}") + -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev') - parser.add_argument('--width', type=int, default=1024, help='width of the image to generate') - parser.add_argument('--height', type=int, default=1024, help='height of the image to generate') - parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev') - parser.add_argument('--seed', type=int, default=None, help='seed for inference') - parser.add_argument('--profile', action='store_true', help='enable profiling') - parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.') - parser.add_argument('--itters', type=int, default=15, help='tiems to run inference and get avg time in sec.') + parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev") + parser.add_argument("--width", type=int, default=1024, help="width of the image to generate") + parser.add_argument("--height", type=int, default=1024, help="height of the image to generate") + parser.add_argument("--guidance", type=float, default=3.5, help="gauidance strentgh for dev") + parser.add_argument("--seed", type=int, default=None, help="seed for inference") + parser.add_argument("--profile", action="store_true", help="enable profiling") + parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.") + parser.add_argument("--itters", type=int, default=15, help="tiems to run inference and get avg time in sec.") args = parser.parse_args() if args.schnell: ckpt_id = "black-forest-labs/FLUX.1-schnell" else: ckpt_id = "black-forest-labs/FLUX.1-dev" - text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to('cpu') - xmp.spawn(_main, args=(args, text_pipe, ckpt_id)) \ No newline at end of file + text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu") + xmp.spawn(_main, args=(args, text_pipe, ckpt_id)) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 942eefc390be..f6265c177b08 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2322,7 +2322,7 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2522,9 +2522,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3503,7 +3503,7 @@ def __call__( query /= math.sqrt(head_dim) hidden_states = flash_attention(query, key, value, causal=False) - + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3523,7 +3523,7 @@ def __call__( return hidden_states, encoder_hidden_states else: return hidden_states - + class MochiVaeAttnProcessor2_0: r""" From 2d7c19850640789078aa3199793e8b5b021bc2ba Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 13 Jan 2025 11:12:57 -0800 Subject: [PATCH 4/6] Update src/diffusers/models/attention_processor.py Co-authored-by: YiYi Xu --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f6265c177b08..0af10063c8de 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -297,7 +297,7 @@ def __init__( self.set_processor(processor) def set_use_xla_flash_attention( - self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs + self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, is_flux = False, ) -> None: r""" Set whether to use xla flash attention from `torch_xla` or not. From e35dd67f3676fa382569dd3a8c75b4822ef34e5c Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 13 Jan 2025 12:02:31 -0800 Subject: [PATCH 5/6] Update src/diffusers/models/attention_processor.py Co-authored-by: YiYi Xu --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0af10063c8de..d5ae1737c545 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -316,7 +316,7 @@ def set_use_xla_flash_attention( elif is_spmd() and is_torch_xla_version("<", "2.4"): raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" else: - if len(kwargs) > 0 and kwargs.get("is_flux", None): + if is_flux: processor = XLAFluxFlashAttnProcessor2_0(partition_spec) else: processor = XLAFlashAttnProcessor2_0(partition_spec) From 40b154ab5a90937bfcd74b63967297bd85b4260b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 15 Jan 2025 17:55:44 +0000 Subject: [PATCH 6/6] run style and quality --- src/diffusers/models/attention_processor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d5ae1737c545..549bf1d8f1ce 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -297,7 +297,10 @@ def __init__( self.set_processor(processor) def set_use_xla_flash_attention( - self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, is_flux = False, + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, ) -> None: r""" Set whether to use xla flash attention from `torch_xla` or not.