The `flux_inference` script shows how to generate images with Flux on TPU devices using PyTorch/XLA. It uses the Pallas flash attention kernel for faster generation.
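As a rough illustration of what that kernel provides, the sketch below calls the Pallas flash attention kernel directly through `torch_xla`'s experimental API. This is not code from the script; the shapes and the torch_xla version are assumptions.

```python
# Minimal sketch, assuming torch_xla >= 2.4 with the experimental Pallas
# kernels available on TPU. Tensor shapes below are illustrative only.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 24, 1024, 128, dtype=torch.bfloat16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attention(q, k, v)  # fused attention via a Pallas TPU kernel
xm.mark_step()  # materialize the lazily traced graph on the device
```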
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPUs; no other TPU types have been tested.
## Create TPU
To create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e).
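For example, a Trillium (v6e) TPU VM can be created with `gcloud` along these lines. The instance name, project, and zone below are placeholders, and the runtime version may differ by zone; check the guide above for the values that apply to you.

```bash
gcloud compute tpus tpu-vm create flux-tpu \
  --project=YOUR_PROJECT \
  --zone=us-east5-b \
  --accelerator-type=v6e-4 \
  --version=v2-alpha-tpuv6e
```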
Run the following command to authenticate with your Hugging Face token, which is required to download the Flux weights.
```bash
huggingface-cli login
```
Then run:
```bash
python flux_inference.py
```
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first run takes longer because XLA compiles the model and stores the compiled programs in a cache; subsequent runs compile much faster, and later passes within a run are the fastest.
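A minimal sketch of that device placement and of enabling PyTorch/XLA's persistent compilation cache is shown below. The model ID and cache path are assumptions, and the real `flux_inference.py` may differ.

```python
# Sketch only: device split and compilation cache, not the full script.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from diffusers import FluxPipeline

# Persist compiled XLA programs so later runs skip most compilation work.
xr.initialize_cache("/tmp/xla_compile_cache", readonly=False)

device = xm.xla_device()
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Text encoders stay on the CPU; the transformer and VAE run on the TPU.
pipe.transformer.to(device)
pipe.vae.to(device)
# Prompt embeddings computed on the CPU would then need to be moved to
# the TPU before the denoising loop.
```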
On a Trillium v6e-4, you should expect ~9 sec per batch of 4 images, i.e. 2.25 sec per image, since the devices run generation in parallel:
```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set`add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
```
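The per-image figure above comes from each of the four chips generating one image in parallel. A hypothetical sketch of that pattern with `torch_xla` multiprocessing follows; the function body is a stand-in, not the script's code.

```python
# Hypothetical sketch: xmp.spawn starts one process per TPU device, so a
# v6e-4 runs four generations concurrently (~9 s for 4 images).
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _generate(index):
    device = xm.xla_device()
    # Each process would build its own pipeline and generate on `device`.
    print(f"process {index} using {device}")


if __name__ == "__main__":
    xmp.spawn(_generate, args=())
```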