@@ -10,29 +10,29 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
- # Torch2.0 support in Diffusers
+ # Accelerated PyTorch 2.0 support in Diffusers
Starting from version `0.13.0`, Diffusers supports the latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
- 1. Support for native flash and memory-efficient attention without any extra dependencies.
- 2. [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for compiling individual models for extra performance boost.
+ 1. Support for accelerated transformers implementation with memory-efficient attention – no extra dependencies required.
+ 2. [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for extra performance boost when individual models are compiled.
## Installation
- To benefit from the native efficient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable version is yet to be released. The first step is to install CUDA11.7 or CUDA11.8,
- as torch2.0 does not support the previous versions. Once CUDA is installed, torch nightly can be installed using:
+ To benefit from the accelerated transformers implementation and `torch.compile`, we will need to install the nightly version of PyTorch, as the stable version is yet to be released. The first step is to install CUDA 11.7 or CUDA 11.8,
+ as PyTorch 2.0 does not support the previous versions. Once CUDA is installed, torch nightly can be installed using:
```bash
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
```
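
If you want to double-check that the nightly build exposes the new attention function, a quick sanity check is:

```python
import torch

# Diffusers uses the accelerated attention automatically when this function is available.
print(torch.__version__)
print(hasattr(torch.nn.functional, "scaled_dot_product_attention"))
```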
- ## Using efficient attention and torch.compile.
+ ## Using accelerated transformers and torch.compile.
- 1. **Efficient Attention**
+ 1. **Accelerated Transformers implementation**
- Efficient attention is implemented via the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into PyTorch.
+ PyTorch 2.0 includes an optimized and memory-efficient attention implementation through the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables several optimizations depending on the inputs and the GPU type. This is similar to the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers), but built natively into PyTorch.
- Efficient attention will be enabled by default in Diffusers if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can install torch2.0 as suggested above and use the pipeline. For example:
+ These optimizations will be enabled by default in Diffusers if PyTorch 2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use them, just install `torch 2.0` as suggested above and use the pipeline as usual. For example:
```Python
import torch
@@ -59,12 +59,12 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
image = pipe(prompt).images[0]
```
- This should be as fast and memory efficient as `xFormers`.
+ This should be as fast and memory efficient as `xFormers`. More details [in our benchmark](#benchmark).
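
If you want to verify the memory behaviour on your own hardware, one simple option is to read PyTorch's peak-memory counter around a generation (a minimal sketch, reusing the `pipe` and `prompt` from the example above):

```python
import torch

# Reset the peak-memory statistics, run one generation, then report the peak usage.
torch.cuda.reset_peak_memory_stats()
image = pipe(prompt).images[0]
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```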
2. **torch.compile**
- To get an additional speedup, we can use the new `torch.compile` feature. To do so, we wrap our `unet` with `torch.compile`. For more information and different options, refer to the
+ To get an additional speedup, we can use the new `torch.compile` feature. To do so, we simply wrap our `unet` with `torch.compile`. For more information and different options, refer to the
[torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
```python
@@ -81,22 +81,23 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
```
- Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on the more recent GPU architectures, such as in the A100.
+ Depending on the type of GPU, `compile()` can yield an _additional speed-up_ of 2-9% over the accelerated transformers optimizations. Note, however, that compilation squeezes out more performance on more recent GPU architectures such as Ampere (A100, 3090), Ada (4090) and Hopper (H100).
- Note that compilation will also take some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
+ Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
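
For instance, a typical pattern looks roughly like the following sketch (the prompts are placeholders; the first call is slow because it triggers compilation):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
pipe.unet = torch.compile(pipe.unet)

prompts = ["a photo of an astronaut riding a horse on mars"] * 4

# First call: slow, includes compilation of the UNet.
_ = pipe(prompts[0], num_inference_steps=50).images

# Subsequent calls with the same shapes reuse the compiled graph.
for prompt in prompts[1:]:
    image = pipe(prompt, num_inference_steps=50).images[0]
```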
## Benchmark
We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
- For the benchmark we used the the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. `xFormers` benchmark is done using the `torch==1.13.1` version. The table below summarizes the result that we got.
- The `Speed over xformers` columns denotes the speed-up gained over `xFormers` using the `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
+ For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. The `xFormers` benchmark is done using the `torch==1.13.1` version, while the accelerated transformers optimizations are tested using nightly versions of PyTorch 2.0. The tables below summarize the results we got.
+
+ The `Speed over xformers` columns denote the speed-up gained over `xFormers` when using `torch.compile` together with `torch.nn.functional.scaled_dot_product_attention`.
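
The exact benchmarking harness is not reproduced here, but a minimal timing loop along these lines can be used to measure a single configuration on your own GPU (a sketch with a placeholder prompt, assuming fp16 weights and one warm-up run):

```python
import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
# Optionally: pipe.unet = torch.compile(pipe.unet)

prompt = "a photo of an astronaut riding a horse on mars"

# Warm-up run (also triggers compilation when torch.compile is used).
_ = pipe(prompt, num_inference_steps=50).images

torch.cuda.synchronize()
start = time.perf_counter()
_ = pipe(prompt, num_inference_steps=50).images
torch.cuda.synchronize()
print(f"{time.perf_counter() - start:.2f} s")
```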
### FP16 benchmark
The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested.
- And using `torch.compile` gives further speed-up up to 10% over `xFormers`, but it's mostly noticeable on the A100 GPU.
+ And using `torch.compile` gives a further speed-up of up to 10% over `xFormers`, but it's mostly noticeable on the A100 GPU.
___The time reported is in seconds.___
@@ -105,7 +106,7 @@ ___The time reported is in seconds.___
| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
- | A100 | 64(2) | | 52.51 | 53.03 | 47.81 | 8.95 |
+ | A100 | 64 | | 52.51 | 53.03 | 47.81 | 8.95 |
| | | | | | | |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
@@ -137,13 +138,20 @@ ___The time reported is in seconds.___
| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
+ | | | | | | | |
+ | 4090 | 4 | 10.48 | 8.37 | 8.32 | 8.01 | 4.30 |
+ | 4090 | 8 | 14.33 | 10.22 | 10.42 | 9.78 | 4.31 |
+ | 4090 | 16 | | 17.07 | 17.46 | 17.15 | -0.47 |
+ | 4090 | 32 (1) | | 39.03 | 39.86 | 37.97 | 2.72 |
+ | 4090 | 64 (1) | | 77.29 | 79.44 | 77.67 | -0.49 |
### FP32 benchmark
- The table below shows the benchmark results for inference using `fp32`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested.
- Using `torch.compile` with efficient attention gives up to 18% performance improvement over `xFormers` in Ampere cards, and up to 20% over vanilla attention.
+ The table below shows the benchmark results for inference using `fp32`. In this case, `torch.nn.functional.scaled_dot_product_attention` is faster than `xFormers` on all the GPUs we tested.
+
+ Using `torch.compile` in addition to the accelerated transformers implementation can yield up to 19% performance improvement over `xFormers` on Ampere and Ada cards, and up to 20% (Ampere) or 28% (Ada) over vanilla attention.
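
In case you want to reproduce a full-precision run, loading the pipeline without the fp16 cast keeps the default float32 weights (a minimal sketch):

```python
from diffusers import StableDiffusionPipeline

# Omitting torch_dtype keeps the default float32 weights.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```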
| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
@@ -173,7 +181,7 @@ Using `torch.compile` with efficient attention gives up to 18% performance impro
| | | | | | | |
| 3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
| 3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
- | 3090 | 8 (2) | | 42.59 | 36.75 | 35.59 | 16.44 | |
+ | 3090 | 8 | | 42.59 | 36.75 | 35.59 | 16.44 | |
| 3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
| 3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
| 3090 | 48 | | 241.91 | 207.75 | | 14.12 | |
@@ -185,12 +193,12 @@ Using `torch.compile` with efficient attention gives up to 18% performance impro
| 3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
| 3090 Ti | 48 | | 213.19 | 186.55 | | 12.50 | |
| | | | | | | |
- | 4090 | 1 | 5.54 | 4.99 | 4.51 | | | |
- | 4090 | 4 | 13.67 | 11.4 | 10.3 | | | |
- | 4090 | 8 (2) | | 19.79 | 17.13 | | | |
- | 4090 | 16 | | 38.62 | 33.14 | | | |
- | 4090 | 32 (1) | | 76.57 | 65.96 | | | |
- | 4090 | 48 | | 114.44 | 98.78 | | | |
+ | 4090 | 1 | 5.54 | 4.99 | 4.51 | 4.44 | 11.02 | 19.86 |
+ | 4090 | 4 | 13.67 | 11.4 | 10.3 | 9.84 | 13.68 | 28.02 |
+ | 4090 | 8 | | 19.79 | 17.13 | 16.19 | 18.19 | |
+ | 4090 | 16 | | 38.62 | 33.14 | 32.31 | 16.34 | |
+ | 4090 | 32 (1) | | 76.57 | 65.96 | 62.05 | 18.96 | |
+ | 4090 | 48 | | 114.44 | 98.78 | | 13.68 | |