[docs] Memory optims #11385

Merged · 8 commits · May 1, 2025
4 changes: 1 addition & 3 deletions docs/source/en/_toctree.yml
@@ -21,8 +21,6 @@
title: Load LoRAs for inference
- local: tutorials/fast_diffusion
title: Accelerate inference of text-to-image diffusion models
- local: tutorials/inference_with_big_models
title: Working with big models
title: Tutorials
- sections:
- local: using-diffusers/loading
@@ -180,7 +178,7 @@
title: Quantization Methods
- sections:
- local: optimization/fp16
title: Speed up inference
title: Accelerate inference
- local: optimization/memory
title: Reduce memory usage
- local: optimization/torch2.0
239 changes: 165 additions & 74 deletions docs/source/en/optimization/fp16.md
@@ -10,120 +10,211 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

# Speed up inference
# Accelerate inference

There are several ways to optimize Diffusers for inference speed, such as reducing the computational burden by lowering the data precision or using a lightweight distilled model. There are also memory-efficient attention implementations, [xFormers](xformers) and [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) in PyTorch 2.0, that reduce memory usage which also indirectly speeds up inference. Different speed optimizations can be stacked together to get the fastest inference times.
Diffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of "steps". To speed up this process, try experimenting with different [schedulers](../api/schedulers/overview), reducing the precision of the model weights for faster computations, using more memory-efficient attention mechanisms, and more.

> [!TIP]
> Optimizing for inference speed or reduced memory usage can lead to improved performance in the other category, so you should try to optimize for both whenever you can. This guide focuses on inference speed, but you can learn more about lowering memory usage in the [Reduce memory usage](memory) guide.
Combine and use these techniques together to make inference faster than using any single technique on its own.

This guide will go over how to accelerate inference.

The inference times below are obtained from generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps on a NVIDIA A100.
## Model data type

| setup | latency | speed-up |
|----------|---------|----------|
| baseline | 5.27s | x1 |
| tf32 | 4.14s | x1.27 |
| fp16 | 3.51s | x1.50 |
| combined | 3.41s | x1.54 |
The precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.

## TensorFloat-32
<hfoptions id="dtypes">
<hfoption id="bfloat16">

On Ampere and later CUDA devices, matrix multiplications and convolutions can use the [TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode for faster, but slightly less accurate computations. By default, PyTorch enables tf32 mode for convolutions but not matrix multiplications. Unless your network requires full float32 precision, we recommend enabling tf32 for matrix multiplications. It can significantly speed up computations with typically negligible loss in numerical accuracy.
bfloat16 is similar to float16 but more robust to numerical errors. Hardware support for bfloat16 varies, but most modern GPUs support it.

```python
```py
import torch
from diffusers import StableDiffusionXLPipeline

torch.backends.cuda.matmul.allow_tf32 = True
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```

</hfoption>
<hfoption id="float16">

float16 is similar to bfloat16 but may be more prone to numerical errors.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```

Learn more about tf32 in the [Mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#tf32) guide.
</hfoption>
<hfoption id="TensorFloat-32">

## Half-precision weights
[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode is supported on NVIDIA Ampere and later GPUs, and it computes the convolution and matrix multiplication operations in tf32. Storage and other operations are kept in float32. This enables significantly faster computations when combined with bfloat16 or float16.

To save GPU memory and get more speed, set `torch_dtype=torch.float16` to load and run the model weights directly with half-precision weights.
PyTorch only enables tf32 mode for convolutions by default and you'll need to explicitly enable it for matrix multiplications.

```Python
```py
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionXLPipeline

pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
)
pipe = pipe.to("cuda")
torch.backends.cuda.matmul.allow_tf32 = True

pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```

> [!WARNING]
> Don't use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
Refer to the [mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision) docs for more details.

## Distilled model
</hfoption>
</hfoptions>

You could also use a distilled Stable Diffusion model and autoencoder to speed up inference. During distillation, many of the UNet's residual and attention blocks are shed to reduce the model size by 51% and improve latency on CPU/GPU by 43%. The distilled model is faster and uses less memory while generating images of comparable quality to the full Stable Diffusion model.
## Scaled dot product attention

> [!TIP]
> Read the [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) blog post to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
> Memory-efficient attention optimizes for inference speed *and* [memory usage](./memory#memory-efficient-attention)!

The inference times below are obtained from generating 4 images from the prompt "a photo of an astronaut riding a horse on mars" with 25 PNDM steps on a NVIDIA A100. Each generation is repeated 3 times with the distilled Stable Diffusion v1.4 model by [Nota AI](https://hf.co/nota-ai).
[Scaled dot product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) implements several attention backends: [FlashAttention](https://github.com/Dao-AILab/flash-attention), [xFormers](https://github.com/facebookresearch/xformers), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.

| setup | latency | speed-up |
|------------------------------|---------|----------|
| baseline | 6.37s | x1 |
| distilled | 4.18s | x1.52 |
| distilled + tiny autoencoder | 3.83s | x1.66 |

Let's load the distilled Stable Diffusion model and compare it against the original Stable Diffusion model.
SDPA is enabled by default if you're using PyTorch >= 2.0, and no additional changes to your code are required. You can also experiment with other attention backends if you'd like to choose one yourself. The example below uses the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable efficient attention.

```py
from diffusers import StableDiffusionPipeline
from torch.nn.attention import SDPBackend, sdpa_kernel
import torch
from diffusers import StableDiffusionXLPipeline

distilled = StableDiffusionPipeline.from_pretrained(
"nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "a golden vase with different flowers"
generator = torch.manual_seed(2023)
image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
image

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
image = pipeline(prompt, num_inference_steps=30).images[0]
```

<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/original_sd.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">original Stable Diffusion</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/distilled_sd.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion</figcaption>
</div>
</div>
## torch.compile

### Tiny AutoEncoder
[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) accelerates inference by compiling PyTorch code and operations into optimized kernels. Diffusers typically compiles the more compute-intensive models like the UNet, transformer, or VAE.

To speed inference up even more, replace the autoencoder with a [distilled version](https://huggingface.co/sayakpaul/taesdxl-diffusers) of it.
Enable the following compiler settings for maximum speed (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options).

```py
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
```

Load and compile the UNet and VAE. There are several different modes you can choose from, but `"max-autotune"` optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs effectively reduce overhead by launching multiple GPU operations through a single CPU operation.

distilled = StableDiffusionPipeline.from_pretrained(
"nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
> [!TIP]
> With PyTorch 2.3.1, you can control the caching behavior of torch.compile. This is particularly beneficial for compilation modes like `"max-autotune"`, which perform a grid search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.

Changing the memory layout to [channels_last](./memory#torchchannels_last) also optimizes memory and inference speed.

```py
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
distilled.vae = AutoencoderTiny.from_pretrained(
"sayakpaul/taesd-diffusers", torch_dtype=torch.float16, use_safetensors=True,
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(
pipeline.unet, mode="max-autotune", fullgraph=True
)
pipeline.vae.decode = torch.compile(
pipeline.vae.decode,
mode="max-autotune",
fullgraph=True
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```

Compilation is slow the first time, but once compiled, it is significantly faster. Try to use the compiled pipeline only on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation, which is slow and inefficient.

### Graph breaks

It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.

```diff
- latents = unet(
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
-).sample

+ latents = unet(
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
+)[0]
```

### GPU sync

The `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser makes a prediction, and the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). When `sigmas` is placed on the GPU, indexing it introduces latency because of the communication sync between the CPU and GPU. This becomes more noticeable once the denoiser has already been compiled.

In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
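
The snippet below is a minimal, self-contained sketch (not the Diffusers scheduler code) of the difference, assuming a CUDA device is available. Reading a value back from a GPU-resident `sigmas` tensor forces the CPU to wait for queued GPU work, while the same read from a CPU tensor is a cheap local access.

```py
import torch

# hypothetical sigma schedule; not taken from a real scheduler
num_steps = 30
sigmas_gpu = torch.linspace(1.0, 0.0, num_steps, device="cuda")
sigmas_cpu = torch.linspace(1.0, 0.0, num_steps)  # stays on the CPU

# indexing + .item() on a CUDA tensor is a device-to-host sync point
sigma = sigmas_gpu[10].item()

# the same read on a CPU tensor involves no synchronization
sigma = sigmas_cpu[10].item()
```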

## Dynamic quantization

[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
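
To make the idea concrete before the torchao example, here is a small, self-contained sketch (toy code, not the torchao implementation) of per-tensor dynamic int8 quantization: the scale is computed from the activation values observed at runtime instead of a fixed, pre-computed factor.

```py
import torch

# activations observed at runtime
x = torch.randn(4, 8)

# scale derived from the data itself rather than a fixed calibration value
scale = x.abs().max() / 127
x_int8 = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)

# a real backend performs the math in int8; dequantize to approximate x
x_dequant = x_int8.float() * scale
```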

The example below applies [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE with the [torchao](../quantization/torchao) library.

> [!TIP]
> Refer to our [torchao](../quantization/torchao) docs to learn more about how to use the Diffusers torchao integration.

Configure the compiler settings for maximum speed.

```py
import torch
from torchao import apply_dynamic_quant
from diffusers import StableDiffusionXLPipeline

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
```

Use the [dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) to filter out the linear layers in the UNet and VAE that don't benefit from dynamic quantization.

```py
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "a golden vase with different flowers"
generator = torch.manual_seed(2023)
image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
image
apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```

<div class="flex justify-center">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/distilled_sd_vae.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion + Tiny AutoEncoder</figcaption>
</div>
</div>
## Fused projection matrices

More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
> [!WARNING]
> The [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034) method is experimental and support is mostly limited to Stable Diffusion pipelines. Take a look at this [PR](https://github.com/huggingface/diffusers/pull/6179) to learn more about how to enable it for other pipelines.

In an attention block, an input is projected into three subspaces, represented by the projection matrices Q, K, and V. These projections are typically calculated separately, but you can horizontally combine them into a single matrix and perform the projection in a single step. This increases the size of the matrix multiplication for the input projections and improves the impact of quantization.

```py
pipeline.fuse_qkv_projections()
```
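
For intuition, here is a conceptual sketch (toy code, not the Diffusers implementation) of what fusing the projections means: the three projection weights are concatenated horizontally so the input is projected with a single matrix multiplication and then split into Q, K, and V.

```py
import torch

dim = 64
x = torch.randn(2, 16, dim)  # (batch, seq, dim)
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

# fused weight: one (dim, 3 * dim) matrix instead of three (dim, dim) matrices
w_qkv = torch.cat([w_q, w_k, w_v], dim=1)
q, k, v = (x @ w_qkv).chunk(3, dim=-1)  # one projection, then split

# equivalent to computing the three projections separately
assert torch.allclose(q, x @ w_q, atol=1e-5)
```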