
Commit aeac0a0

implementing flux on TPUs with ptxla (#10515)
* implementing flux on TPUs with ptxla
* add xla flux attention class
* run make style/quality
* Update src/diffusers/models/attention_processor.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/attention_processor.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* run style and quality

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent cecada5 commit aeac0a0

File tree

7 files changed: +335 -9 lines changed

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Generating images using Flux and PyTorch/XLA

The `flux_inference` script shows how to generate images with Flux on TPU devices using PyTorch/XLA. It uses the Pallas flash attention kernel for faster generation.

It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPUs. No other TPU types have been tested.

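The switch to the Pallas kernel is a single call on the transformer. Below is a minimal sketch of that call (it assumes the environment set up in the next sections and uses the same checkpoint and `partition_spec` as `flux_inference.py`; the actual script additionally keeps the text encoders on the CPU, as shown further down):

```python
import torch
import torch_xla.core.xla_model as xm

from diffusers import FluxPipeline

# Load Flux in bfloat16, move it to the TPU, and route the transformer's attention
# through the pallas flash-attention kernel; this mirrors what flux_inference.py does.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(xm.xla_device())
pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
```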
## Create TPU

To create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e).

## Setup TPU environment

SSH into the VM and install PyTorch and PyTorch/XLA:

```bash
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

Verify that PyTorch and PyTorch/XLA were installed correctly:

```bash
python3 -c "import torch; import torch_xla;"
```

Install the remaining dependencies, then install `diffusers` from the local checkout:

```bash
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
popd
```

## Run the inference job

### Authenticate

Run the following command to authenticate your token in order to download the Flux weights:

```bash
huggingface-cli login
```

Then run:

```bash
python flux_inference.py
```

The script loads the text encoders onto the CPU and the Flux transformer and VAE onto the TPU. The first run takes longer because the models have to be compiled, and the compiled programs are stored in a cache; on subsequent runs compilation is much faster, and the passes after the first are the fastest.

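A condensed sketch of how `flux_inference.py` sets up this split and the compilation cache (not a complete program; the checkpoint id, cache path, and arguments are the ones the script uses):

```python
from pathlib import Path

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from diffusers import FluxPipeline

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Compiled XLA programs land in this cache, so only the first run pays the full compilation cost.
cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
cache_path.mkdir(parents=True, exist_ok=True)
xr.initialize_cache(str(cache_path), readonly=False)

# The text encoders and tokenizers stay on the CPU...
text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")

# ...while the transformer and VAE go to the TPU (in the script, inside each process spawned by xmp.spawn).
flux_pipe = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
).to(xm.xla_device())
```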
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):

```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
2025-01-10 00:51:34 [info ] starting compilation run...
2025-01-10 00:51:35 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
2025-01-10 00:52:53 [info ] starting inference run...
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:58 [info ] starting inference run...
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
```
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr

from diffusers import FluxPipeline


logger = structlog.get_logger()
metrics_filepath = "/tmp/metrics_report.txt"


def _main(index, args, text_pipe, ckpt_id):
    # Persist compiled programs so later runs skip most of the XLA compilation.
    cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
    cache_path.mkdir(parents=True, exist_ok=True)
    xr.initialize_cache(str(cache_path), readonly=False)

    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)
    device0 = xm.xla_device()

    logger.info(f"loading flux from {ckpt_id}")
    # The transformer and VAE run on the TPU; the text encoders stay in `text_pipe` on the CPU.
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
    ).to(device0)
    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)

    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    # Warmup pass: triggers XLA compilation and populates the cache; the timed runs below reuse it.
    logger.info("starting compilation run...")
    ts = perf_counter()
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
        prompt_embeds = prompt_embeds.to(device0)
        pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

        image = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=28,
            guidance_scale=guidance,
            height=height,
            width=width,
        ).images[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # Give each TPU process its own seed so the devices generate different images.
    base_seed = 4096 if args.seed is None else args.seed
    seed_range = 1000
    unique_seed = base_seed + index * seed_range
    xm.set_rng_state(seed=unique_seed, device=device0)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
                prompt=prompt, prompt_2=None, max_sequence_length=512
            )
            prompt_embeds = prompt_embeds.to(device0)
            pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

            if args.profile:
                xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
            image = flux_pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=n_steps,
                guidance_scale=guidance,
                height=height,
                width=width,
            ).images[0]
        inference_time = perf_counter() - ts
        if index == 0:
            logger.info(f"inference time: {inference_time}")
        times.append(inference_time)
    logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
    image.save(f"/tmp/inference_out-{index}.png")
    if index == 0:
        metrics_report = met.metrics_report()
        with open(metrics_filepath, "w+") as fout:
            fout.write(metrics_report)
        logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="times to run inference and report the average time in sec.")
    args = parser.parse_args()
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"
    # The text encoders are loaded once on the CPU and shared with every spawned TPU process.
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))

src/diffusers/models/attention_processor.py

Lines changed: 111 additions & 5 deletions
@@ -297,7 +297,10 @@ def __init__(
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +319,10 @@ def set_use_xla_flash_attention(
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if is_flux:
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -2318,9 +2324,8 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2522,6 +2527,7 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -3422,6 +3428,106 @@ def __call__(
         return hidden_states
 
 
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.

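For pipelines this path is reached through `enable_xla_flash_attention(..., is_flux=True)`, as `flux_inference.py` does. A minimal sketch of the module-level selection added above (assumes a TPU host with `torch_xla` >= 2.3 installed; the attention sizes below are hypothetical and only for illustration):

```python
from diffusers.models.attention_processor import Attention, XLAFluxFlashAttnProcessor2_0

# Hypothetical Flux-like attention sizes, purely for illustration.
attn = Attention(query_dim=3072, heads=24, dim_head=128, bias=True)

# With is_flux=True, set_use_xla_flash_attention installs the Flux-specific pallas processor.
attn.set_use_xla_flash_attention(True, partition_spec=("data", None, None, None), is_flux=True)
assert isinstance(attn.processor, XLAFluxFlashAttnProcessor2_0)
```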