From 2f53bc597830524e1c205e0e99355894aa7eeff5 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Jan 2025 10:55:03 +0800 Subject: [PATCH 01/30] add para_attn_flux.md and para_attn_hunyuan_video.md --- docs/source/en/_toctree.yml | 4 + docs/source/en/optimization/para_attn_flux.md | 285 +++++++++++++++ .../optimization/para_attn_hunyuan_video.md | 326 ++++++++++++++++++ 3 files changed, 615 insertions(+) create mode 100644 docs/source/en/optimization/para_attn_flux.md create mode 100644 docs/source/en/optimization/para_attn_hunyuan_video.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a2b411c8fcb0..502b6307404a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -179,6 +179,10 @@ title: TGATE - local: optimization/xdit title: xDiT + - local: optimization/para_attn_flux + title: Fast FLUX Inference with ParaAttention + - local: optimization/para_attn_hunyuan_video + title: Fast HunyuanVideo Inference with ParaAttention - sections: - local: using-diffusers/stable_diffusion_jax_how_to title: JAX/Flax diff --git a/docs/source/en/optimization/para_attn_flux.md b/docs/source/en/optimization/para_attn_flux.md new file mode 100644 index 000000000000..402b543885db --- /dev/null +++ b/docs/source/en/optimization/para_attn_flux.md @@ -0,0 +1,285 @@ +# Fastest FLUX Inference with ParaAttention + +[![](https://mermaid.ink/img/pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw?type=png)](https://mermaid.live/edit#pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw) + +## Introduction + +During the past year, we have seen the rapid development of image generation models with the release of several open-source models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [SD3.5-large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large). +It is very exciting to see that open source image generation models are going to beat closed source. +However, the inference speed of these models is still a bottleneck for real-time applications and deployment. + +In this article, we will use [ParaAttention](https://github.com/chengzeyi/ParaAttention), a library implements **Context Parallelism** and **First Block Cache**, as well as other techniques like `torch.compile` and **FP8 Dynamic Quantization**, to achieve the fastest inference speed for FLUX.1-dev. + +**We set up our experiments on NVIDIA L20 GPUs, which only have PCIe support.** +**If you have NVIDIA A100 or H100 GPUs with NVLink support, you can achieve a better speedup with context parallelism, especially when the number of GPUs is large.** + +## FLUX.1-dev Inference with `diffusers` + +Like many other generative AI models, FLUX.1-dev has its official code repository and is supported by other frameworks like `diffusers` and `ComfyUI`. +In this article, we will focus on optimizing the inference speed of FLUX.1-dev with `diffusers`. +To use FLUX.1-dev with `diffusers`, we need to install its latest version: + +```bash +pip3 install -U diffusers +``` + +Then, we can load the model and generate images with the following code: + +```python +import time +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +# Enable memory savings +# pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() + +begin = time.time() +image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, +).images[0] +end = time.time() +print(f"Time: {end - begin:.2f}s") + +print("Saving image to flux.png") +image.save("flux.png") +``` + +This is our baseline. +On one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 26.36 seconds. + +## Apply First Block Cache on FLUX.1-dev + +By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. +However, it is hard to decide when to reuse the cache to ensure the quality of the generated image. +Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. +And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple image and video DiT baselines. +However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. +In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. +When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. +This has been proved to be effective in our experiments and we can achieve an up to 1.5x speedup on FLUX.1-dev inference with very good quality. + +
+ Cache in Diffusion Transformer +
How AdaCache works, First Block Cache is a variant of it
+
+ +To apply the first block cache on FLUX.1-dev, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.08`, which is the default value for FLUX models. + +```python +import time +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe, residual_diff_threshold=0.08) + +# Enable memory savings +# pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() + +begin = time.time() +image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, +).images[0] +end = time.time() +print(f"Time: {end - begin:.2f}s") + +print("Saving image to flux.png") +image.save("flux.png") +``` + +| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 | +| - | - | - | - | - | - | +| Preview | ![Original](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_original.png) | ![FBCache rdt=0.06](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.06.png) | ![FBCache rdt=0.08](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.08.png) | ![FBCache rdt=0.10](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.10.png) | ![FBCache rdt=0.12](https://github.com/chengzeyi/ParaAttention/raw/main/assets/flux_fbc_0.12.png) | +| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 | + +We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated image. +Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 17.01 seconds. This is a 1.55x speedup compared to the baseline. + +## Quantize the model into FP8 + +To further speed up the inference and reduce memory usage, we can quantize the model into FP8 with dynamic quantization. +We must quantize both the activation and weight of the transformer model to utilize the 8-bit **Tensor Cores** on NVIDIA GPUs. +Here, we use `float8_weight_only` and `float8_dynamic_activation_float8_weight`to quantize the text encoder and transformer model respectively. +The default quantization method is per tensor quantization. If your GPU supports row-wise quantization, you can also try it for better accuracy. +[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) provides a really good tutorial on how to quantize models in `diffusers` and achieve a good speedup. +Here, we simply install the latest `torchao` that is capable of quantizing FLUX.1-dev. +If you are not familiar with `torchao` quantization, you can refer to this [documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md). + +```bash +pip3 install -U torch torchao +``` + +We also need to pass the model to `torch.compile` to gain actual speedup. +`torch.compile` with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` can help us to achieve the best performance by generating and selecting the best kernel for the model inference. +The compilation process could take a long time, but it is worth it. +If you are not familiar with `torch.compile`, you can refer to the [official tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). +In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. +We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. + +**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. + +```python +import time +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe( + pipe, + residual_diff_threshold=0.12, # Use a larger value to make the cache take effect +) + +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only + +quantize_(pipe.text_encoder, float8_weight_only()) +quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune-no-cudagraphs", +) + +# Enable memory savings +# pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() + +for i in range(2): + begin = time.time() + image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, + ).images[0] + end = time.time() + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +print("Saving image to flux.png") +image.save("flux.png") +``` + +We can see that the quantization and compilation process can further speed up the inference. +Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 7.56s, which is a 3.48x speedup compared to the baseline. + +## Parallelize the inference with Context Parallelism + +A lot faster than before, right? But we are not satisfied with the speedup we have achieved so far. +If we want to accelerate the inference further, we can use context parallelism to parallelize the inference. +Libraries like [xDit](https://github.com/xdit-project/xDiT) and our [ParaAttention](https://github.com/chengzeyi/ParaAttention) provide ways to scale up the inference with multiple GPUs. +In ParaAttention, we design our API in a compositional way so that we can combine context parallelism with first block cache and dynamic quantization all together. +We provide very detailed instructions and examples of how to scale up the inference with multiple GPUs in our ParaAttention repository. +Users can easily launch the inference with multiple GPUs by calling `torchrun`. +If there is a need to make the inference process persistent and serviceable, it is suggested to use `torch.multiprocessing` to write your own inference processor, which can eliminate the overhead of launching the process and loading and recompiling the model. + +Below is our ultimate code to achieve the fastest FLUX.1-dev inference: + +```python +import time +import torch +import torch.distributed as dist +from diffusers import FluxPipeline + +dist.init_process_group() + +torch.cuda.set_device(dist.get_rank()) + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.context_parallel import init_context_parallel_mesh +from para_attn.context_parallel.diffusers_adapters import parallelize_pipe +from para_attn.parallel_vae.diffusers_adapters import parallelize_vae + +mesh = init_context_parallel_mesh( + pipe.device.type, + max_ring_dim_size=2, +) +parallelize_pipe( + pipe, + mesh=mesh, +) +parallelize_vae(pipe.vae, mesh=mesh._flatten()) + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe( + pipe, + residual_diff_threshold=0.12, # Use a larger value to make the cache take effect +) + +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only + +quantize_(pipe.text_encoder, float8_weight_only()) +quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +torch._inductor.config.reorder_for_compute_comm_overlap = True +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune-no-cudagraphs", +) + +# Enable memory savings +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) +# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) + +for i in range(2): + begin = time.time() + image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, + output_type="pil" if dist.get_rank() == 0 else "pt", + ).images[0] + end = time.time() + if dist.get_rank() == 0: + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +if dist.get_rank() == 0: + print("Saving image to flux.png") + image.save("flux.png") + +dist.destroy_process_group() +``` + +We save the above code to `run_flux.py` and run it with `torchrun`: + +```bash +# Use --nproc_per_node to specify the number of GPUs +torchrun --nproc_per_node=2 run_flux.py +``` + +With 2 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 8.20 seconds, which is a 3.21x speedup compared to the baseline. +And with 4 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 3.90 seconds, which is a 6.75x speedup compared to the baseline. + +## Conclusion + +| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup | +| - | - | - | - | - | +| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x | +| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x | +| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x | +| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x | +| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x | +| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x | diff --git a/docs/source/en/optimization/para_attn_hunyuan_video.md b/docs/source/en/optimization/para_attn_hunyuan_video.md new file mode 100644 index 000000000000..079982ec2e70 --- /dev/null +++ b/docs/source/en/optimization/para_attn_hunyuan_video.md @@ -0,0 +1,326 @@ +# Fastest HunyuanVideo Inference with ParaAttention + +[![](https://mermaid.ink/img/pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbWmXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY?type=png)](https://mermaid.live/edit#pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbWmXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY) + +## Introduction + +During the past year, we have seen the rapid development of video generation models with the release of several open-source models, such as [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b) and [Mochi](https://huggingface.co/genmo/mochi-1-preview). +It is very exciting to see that open source video models are going to beat closed source. +However, the inference speed of these models is still a bottleneck for real-time applications and deployment. + +In this article, we will use [ParaAttention](https://github.com/chengzeyi/ParaAttention), a library implements **Context Parallelism** and **First Block Cache**, as well as other techniques like `torch.compile` and **FP8 Dynamic Quantization**, to achieve the fastest inference speed for HunyuanVideo. +If you want to speed up other models like `CogVideoX`, `Mochi` or `FLUX`, you can also follow the same steps in this article. + +**We set up our experiments on NVIDIA L20 GPUs, which only have PCIe support.** +**If you have NVIDIA A100 or H100 GPUs with NVLink support, you can achieve a better speedup with context parallelism, especially when the number of GPUs is large.** + +## HunyuanVideo Inference with `diffusers` + +Like many other generative AI models, HunyuanVideo has its official code repository and is supported by other frameworks like `diffusers` and `ComfyUI`. +In this article, we will focus on optimizing the inference speed of HunyuanVideo with `diffusers`. +To use HunyuanVideo with `diffusers`, we need to install its latest version: + +```bash +pip3 install -U diffusers +``` + +Then, we can load the model and generate video frames with the following code: + +```python +import time +import torch +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.utils import export_to_video + +model_id = "tencent/HunyuanVideo" +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, + subfolder="transformer", + torch_dtype=torch.bfloat16, + revision="refs/pr/18", +) +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", +).to("cuda") + +pipe.vae.enable_tiling() + +begin = time.time() +output = pipe( + prompt="A cat walks on the grass, realistic", + height=720, + width=1280, + num_frames=129, + num_inference_steps=30, +).frames[0] +end = time.time() +print(f"Time: {end - begin:.2f}s") + +print("Saving video to hunyuan_video.mp4") +export_to_video(output, "hunyuan_video.mp4", fps=15) +``` + +However, most people will experience OOM (Out of Memory) errors when running the above code. +This is because the HunyuanVideo transformer model is relatively large and it has a quite large text encoder. +Besides, HunyuanVideo requires a variable length of text conditions and the `diffusers` library implements this feature with a `attn_mask` in `scaled_dot_product_attention`. +The size of `attn_mask` is proportional to the square of the input sequence length, which is crazy when we increase the resolution and the number of frames of the inference! +Luckily, we can use ParaAttention to solve this problem. +In ParaAttention, we patch the original implementation in `diffusers` to cut the text conditions before calling `scaled_dot_product_attention`. +We implement this in our `apply_cache_on_pipe` function so we can call it after loading the model: + +```bash +pip3 install -U para-attn +``` + +```python +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe, residual_diff_threshold=0.0) +``` + +We pass `residual_diff_threshold=0.0` to `apply_cache_on_pipe` to disable the cache mechanism now, because we will enable it later. +Here, we only want it to cut the text conditions to avoid OOM errors. +If you still experience OOM errors, you can try calling `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload` after calling `apply_cache_on_pipe`. + +This is our baseline. +On one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 3675.71 seconds. + +## Apply First Block Cache on HunyuanVideo + +By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. +However, it is hard to decide when to reuse the cache to ensure the quality of the generated video. +Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. +And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple video DiT baselines. +However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. +In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. +When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. +This has been proved to be effective in our experiments and we can achieve an up to 2x speedup on HunyuanVideo inference with very good quality. + +
+ Cache in Diffusion Transformer +
How AdaCache works, First Block Cache is a variant of it
+
+ +To apply the first block cache on HunyuanVideo, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.06`, which is the default value for HunyuanVideo. + +```python +apply_cache_on_pipe(pipe, residual_diff_threshold=0.06) +``` + +### HunyuanVideo without FBCache + +https://github.com/user-attachments/assets/883d771a-e74e-4081-aa2a-416985d6c713 + +### HunyuanVideo with FBCache + +https://github.com/user-attachments/assets/f77c2f58-2b59-4dd1-a06a-a36974cb1e40 + +We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated video. +Now, on one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 2271.06 seconds. This is a 1.62x speedup compared to the baseline. + +## Quantize the model into FP8 + +To further speed up the inference and reduce memory usage, we can quantize the model into FP8 with dynamic quantization. +We must quantize both the activation and weight of the transformer model to utilize the 8-bit **Tensor Cores** on NVIDIA GPUs. +Here, we use `float8_weight_only` and `float8_dynamic_activation_float8_weight`to quantize the text encoder and transformer model respectively. +The default quantization method is per tensor quantization. If your GPU supports row-wise quantization, you can also try it for better accuracy. +[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) provides a really good tutorial on how to quantize models in `diffusers` and achieve a good speedup. +Here, we simply install the latest `torchao` that is capable of quantizing HunyuanVideo. +If you are not familiar with `torchao` quantization, you can refer to this [documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md). + +```bash +pip3 install -U torch torchao +``` + +We also need to pass the model to `torch.compile` to gain actual speedup. +`torch.compile` with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` can help us to achieve the best performance by generating and selecting the best kernel for the model inference. +The compilation process could take a long time, but it is worth it. +If you are not familiar with `torch.compile`, you can refer to the [official tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). +In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. +We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. + +**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so you might need to tweak the `residual_diff_threshold` to a larger value to make it take effect. + +```python +import time +import torch +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.utils import export_to_video + +model_id = "tencent/HunyuanVideo" +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, + subfolder="transformer", + torch_dtype=torch.bfloat16, + revision="refs/pr/18", +) +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe) + +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only + +quantize_(pipe.text_encoder, float8_weight_only()) +quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune-no-cudagraphs", +) + +pipe.vae.enable_tiling() +# pipe.enable_model_cpu_offload() + +for i in range(2): + begin = time.time() + output = pipe( + prompt="A cat walks on the grass, realistic", + height=720, + width=1280, + num_frames=129, + num_inference_steps=1 if i == 0 else 30, + ).frames[0] + end = time.time() + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +print("Saving video to hunyuan_video.mp4") +export_to_video(output, "hunyuan_video.mp4", fps=15) +``` + +The NVIDIA L20 GPU only has 48GB memory and could face OOM errors after compiling the model and not calling `enable_model_cpu_offload`, +because the HunyuanVideo has very large activation tensors when running with high resolution and large number of frames. +So here we skip measuring the speedup with quantization and compilation on one single NVIDIA L20 GPU and choose to use context parallelism to release the memory pressure. +If you want to run HunyuanVideo with `torch.compile` on GPUs with less than 80GB memory, you can try reducing the resolution and the number of frames to avoid OOM errors. + +Due to the fact that large video generation models usually have performance bottleneck on the attention computation rather than the fully connected layers, we don't observe a significant speedup with quantization and compilation. +However, models like `FLUX` and `SD3` can benefit a lot from quantization and compilation, it is suggested to try it for these models. + +## Parallelize the inference with Context Parallelism + +A lot faster than before, right? But we are not satisfied with the speedup we have achieved so far. +If we want to accelerate the inference further, we can use context parallelism to parallelize the inference. +Libraries like [xDit](https://github.com/xdit-project/xDiT) and our [ParaAttention](https://github.com/chengzeyi/ParaAttention) provide ways to scale up the inference with multiple GPUs. +In ParaAttention, we design our API in a compositional way so that we can combine context parallelism with first block cache and dynamic quantization all together. +We provide very detailed instructions and examples of how to scale up the inference with multiple GPUs in our ParaAttention repository. +Users can easily launch the inference with multiple GPUs by calling `torchrun`. +If there is a need to make the inference process persistent and serviceable, it is suggested to use `torch.multiprocessing` to write your own inference processor, which can eliminate the overhead of launching the process and loading and recompiling the model. + +Below is our ultimate code to achieve the fastest HunyuanVideo inference: + +```python +import time +import torch +import torch.distributed as dist +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.utils import export_to_video + +dist.init_process_group() + +torch.cuda.set_device(dist.get_rank()) + +# [rank1]: RuntimeError: Expected mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.) +torch.backends.cuda.enable_cudnn_sdp(False) + +model_id = "tencent/HunyuanVideo" +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, + subfolder="transformer", + torch_dtype=torch.bfloat16, + revision="refs/pr/18", +) +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", +).to("cuda") + +from para_attn.context_parallel import init_context_parallel_mesh +from para_attn.context_parallel.diffusers_adapters import parallelize_pipe +from para_attn.parallel_vae.diffusers_adapters import parallelize_vae + +mesh = init_context_parallel_mesh( + pipe.device.type, +) +parallelize_pipe( + pipe, + mesh=mesh, +) +parallelize_vae(pipe.vae, mesh=mesh._flatten()) + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe) + +# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only +# +# torch._inductor.config.reorder_for_compute_comm_overlap = True +# +# quantize_(pipe.text_encoder, float8_weight_only()) +# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +# pipe.transformer = torch.compile( +# pipe.transformer, mode="max-autotune-no-cudagraphs", +# ) + +pipe.vae.enable_tiling() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) + +for i in range(2): + begin = time.time() + output = pipe( + prompt="A cat walks on the grass, realistic", + height=720, + width=1280, + num_frames=129, + num_inference_steps=1 if i == 0 else 30, + output_type="pil" if dist.get_rank() == 0 else "pt", + ).frames[0] + end = time.time() + if dist.get_rank() == 0: + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +if dist.get_rank() == 0: + print("Saving video to hunyuan_video.mp4") + export_to_video(output, "hunyuan_video.mp4", fps=15) + +dist.destroy_process_group() +``` + +We save the above code to `run_hunyuan_video.py` and run it with `torchrun`: + +```bash +torchrun --nproc_per_node=8 run_hunyuan_video.py +``` + +With 8 NVIDIA L20 GPUs, we can generate 129 frames with 720p resolution in 30 inference steps in 649.23 seconds. This is a 5.66x speedup compared to the baseline! + +## Conclusion + +| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup | +| - | - | - | - | - | +| NVIDIA L20 | 1 | Baseline | 3675.71 | 1.00x | +| NVIDIA L20 | 1 | FBCache | 2271.06 | 1.62x | +| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x | +| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x | +| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x | From 578f418cf9698b6bacbd07ca45e0f7bc8e23909b Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Jan 2025 11:02:57 +0800 Subject: [PATCH 02/30] add enable_sequential_cpu_offload in para_attn_hunyuan_video.md --- docs/source/en/optimization/para_attn_hunyuan_video.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/optimization/para_attn_hunyuan_video.md b/docs/source/en/optimization/para_attn_hunyuan_video.md index 079982ec2e70..2e96f72b8a5a 100644 --- a/docs/source/en/optimization/para_attn_hunyuan_video.md +++ b/docs/source/en/optimization/para_attn_hunyuan_video.md @@ -185,6 +185,7 @@ pipe.transformer = torch.compile( pipe.vae.enable_tiling() # pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() for i in range(2): begin = time.time() @@ -282,6 +283,7 @@ apply_cache_on_pipe(pipe) pipe.vae.enable_tiling() # pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) +# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) for i in range(2): begin = time.time() From c21cb085ebdc88a8313f3aea5435ffd2f40127e8 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 13 Jan 2025 11:06:18 +0800 Subject: [PATCH 03/30] add comment --- docs/source/en/optimization/para_attn_hunyuan_video.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn_hunyuan_video.md b/docs/source/en/optimization/para_attn_hunyuan_video.md index 2e96f72b8a5a..e957dec3e569 100644 --- a/docs/source/en/optimization/para_attn_hunyuan_video.md +++ b/docs/source/en/optimization/para_attn_hunyuan_video.md @@ -90,7 +90,7 @@ apply_cache_on_pipe(pipe, residual_diff_threshold=0.0) We pass `residual_diff_threshold=0.0` to `apply_cache_on_pipe` to disable the cache mechanism now, because we will enable it later. Here, we only want it to cut the text conditions to avoid OOM errors. -If you still experience OOM errors, you can try calling `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload` after calling `apply_cache_on_pipe`. +If you still experience OOM errors, you can try calling `pipe.enable_model_cpu_offload` or `pipe.enable_sequential_cpu_offload` after calling `apply_cache_on_pipe`. This is our baseline. On one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 3675.71 seconds. @@ -183,6 +183,7 @@ pipe.transformer = torch.compile( pipe.transformer, mode="max-autotune-no-cudagraphs", ) +# Enable memory savings pipe.vae.enable_tiling() # pipe.enable_model_cpu_offload() # pipe.enable_sequential_cpu_offload() @@ -281,6 +282,7 @@ apply_cache_on_pipe(pipe) # pipe.transformer, mode="max-autotune-no-cudagraphs", # ) +# Enable memory savings pipe.vae.enable_tiling() # pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) # pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) From 6122df17de271463875e6848b9bbf8ac52e45f54 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Tue, 14 Jan 2025 16:46:11 +0800 Subject: [PATCH 04/30] refactor --- docs/source/en/_toctree.yml | 6 +- ...ara_attn_hunyuan_video.md => para_attn.md} | 356 +++++++++++++----- docs/source/en/optimization/para_attn_flux.md | 285 -------------- docs/source/en/optimization/xdit.md | 2 +- 4 files changed, 275 insertions(+), 374 deletions(-) rename docs/source/en/optimization/{para_attn_hunyuan_video.md => para_attn.md} (51%) delete mode 100644 docs/source/en/optimization/para_attn_flux.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 502b6307404a..3bd7f1987a00 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -179,10 +179,8 @@ title: TGATE - local: optimization/xdit title: xDiT - - local: optimization/para_attn_flux - title: Fast FLUX Inference with ParaAttention - - local: optimization/para_attn_hunyuan_video - title: Fast HunyuanVideo Inference with ParaAttention + - local: optimization/para_attn + title: ParaAttention - sections: - local: using-diffusers/stable_diffusion_jax_how_to title: JAX/Flax diff --git a/docs/source/en/optimization/para_attn_hunyuan_video.md b/docs/source/en/optimization/para_attn.md similarity index 51% rename from docs/source/en/optimization/para_attn_hunyuan_video.md rename to docs/source/en/optimization/para_attn.md index e957dec3e569..f9a46aadb288 100644 --- a/docs/source/en/optimization/para_attn_hunyuan_video.md +++ b/docs/source/en/optimization/para_attn.md @@ -1,30 +1,95 @@ -# Fastest HunyuanVideo Inference with ParaAttention +# ParaAttention -[![](https://mermaid.ink/img/pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbWmXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY?type=png)](https://mermaid.live/edit#pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbWmXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY) +
+ +
+
+ +
## Introduction -During the past year, we have seen the rapid development of video generation models with the release of several open-source models, such as [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b) and [Mochi](https://huggingface.co/genmo/mochi-1-preview). -It is very exciting to see that open source video models are going to beat closed source. -However, the inference speed of these models is still a bottleneck for real-time applications and deployment. +Large image and video generation models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), can be an inference challenge for real-time applications and deployment because of their size. -In this article, we will use [ParaAttention](https://github.com/chengzeyi/ParaAttention), a library implements **Context Parallelism** and **First Block Cache**, as well as other techniques like `torch.compile` and **FP8 Dynamic Quantization**, to achieve the fastest inference speed for HunyuanVideo. -If you want to speed up other models like `CogVideoX`, `Mochi` or `FLUX`, you can also follow the same steps in this article. +[ParaAttention](https://github.com/chengzeyi/ParaAttention) is a library that implements **context parallelism** and **first block cache**, and can be combined with other techniques (torch.compile, fp8 dynamic quantization), to accelerate inference. -**We set up our experiments on NVIDIA L20 GPUs, which only have PCIe support.** -**If you have NVIDIA A100 or H100 GPUs with NVLink support, you can achieve a better speedup with context parallelism, especially when the number of GPUs is large.** +This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs with only PCIe support. -## HunyuanVideo Inference with `diffusers` +> [!TIP] +> For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs. -Like many other generative AI models, HunyuanVideo has its official code repository and is supported by other frameworks like `diffusers` and `ComfyUI`. -In this article, we will focus on optimizing the inference speed of HunyuanVideo with `diffusers`. -To use HunyuanVideo with `diffusers`, we need to install its latest version: +## Baseline Performance -```bash -pip3 install -U diffusers +Our baseline performance is measured on a single NVIDIA L20 GPU, with normal FP16/BF16 precision, For FLUX.1-dev, we don't apply any optimization, and for HunyuanVideo, we only apply the patch `apply_cache_on_pipe(pipe, residual_diff_threshold=0.0)` to make it run without a square `attn_mask` to avoid OOM errors. + +For FLUX.1-dev, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 26.36 seconds. + +For HunyuanVideo, we can generate 129 frames with 720p resolution in 30 inference steps in 3675.71 seconds. + +### First Block Cache + +By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. +However, it is hard to decide when to reuse the cache to ensure the quality of the generated image or video. +Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. +And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple image and video DiT baselines. +However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. +In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. +When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. + +This has been proved to be effective in our experiments and we can achieve an 2.0x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality. + +
+ Cache in Diffusion Transformer +
How AdaCache works, First Block Cache is a variant of it
+
+ + + + +To apply first block cache on FLUX.1-dev, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.08`, which is the default value for FLUX models. + +```python +import time +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe, residual_diff_threshold=0.08) + +# Enable memory savings +# pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() + +begin = time.time() +image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, +).images[0] +end = time.time() +print(f"Time: {end - begin:.2f}s") + +print("Saving image to flux.png") +image.save("flux.png") ``` -Then, we can load the model and generate video frames with the following code: +| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 | +| - | - | - | - | - | - | +| Preview | ![Original](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache rdt=0.06](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) | +| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 | + +We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated image. +Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 17.01 seconds. This is a 1.55x speedup compared to the baseline. + + + + +To apply first block cache on HunyuanVideo, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.06`, which is the default value for HunyuanVideo. ```python import time @@ -46,6 +111,10 @@ pipe = HunyuanVideoPipeline.from_pretrained( revision="refs/pr/18", ).to("cuda") +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe(pipe, residual_diff_threshold=0.6) + pipe.vae.enable_tiling() begin = time.time() @@ -63,79 +132,34 @@ print("Saving video to hunyuan_video.mp4") export_to_video(output, "hunyuan_video.mp4", fps=15) ``` -However, most people will experience OOM (Out of Memory) errors when running the above code. -This is because the HunyuanVideo transformer model is relatively large and it has a quite large text encoder. -Besides, HunyuanVideo requires a variable length of text conditions and the `diffusers` library implements this feature with a `attn_mask` in `scaled_dot_product_attention`. -The size of `attn_mask` is proportional to the square of the input sequence length, which is crazy when we increase the resolution and the number of frames of the inference! -Luckily, we can use ParaAttention to solve this problem. -In ParaAttention, we patch the original implementation in `diffusers` to cut the text conditions before calling `scaled_dot_product_attention`. -We implement this in our `apply_cache_on_pipe` function so we can call it after loading the model: - -```bash -pip3 install -U para-attn -``` - -```python -pipe = HunyuanVideoPipeline.from_pretrained( - model_id, - transformer=transformer, - torch_dtype=torch.float16, - revision="refs/pr/18", -).to("cuda") - -from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe - -apply_cache_on_pipe(pipe, residual_diff_threshold=0.0) -``` - -We pass `residual_diff_threshold=0.0` to `apply_cache_on_pipe` to disable the cache mechanism now, because we will enable it later. -Here, we only want it to cut the text conditions to avoid OOM errors. -If you still experience OOM errors, you can try calling `pipe.enable_model_cpu_offload` or `pipe.enable_sequential_cpu_offload` after calling `apply_cache_on_pipe`. - -This is our baseline. -On one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 3675.71 seconds. +#### HunyuanVideo without FBCache -## Apply First Block Cache on HunyuanVideo + -By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. -However, it is hard to decide when to reuse the cache to ensure the quality of the generated video. -Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. -And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple video DiT baselines. -However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. -In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. -When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. -This has been proved to be effective in our experiments and we can achieve an up to 2x speedup on HunyuanVideo inference with very good quality. - -
- Cache in Diffusion Transformer -
How AdaCache works, First Block Cache is a variant of it
-
+#### HunyuanVideo with FBCache -To apply the first block cache on HunyuanVideo, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.06`, which is the default value for HunyuanVideo. - -```python -apply_cache_on_pipe(pipe, residual_diff_threshold=0.06) -``` - -### HunyuanVideo without FBCache - -https://github.com/user-attachments/assets/883d771a-e74e-4081-aa2a-416985d6c713 - -### HunyuanVideo with FBCache - -https://github.com/user-attachments/assets/f77c2f58-2b59-4dd1-a06a-a36974cb1e40 + We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated video. Now, on one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 2271.06 seconds. This is a 1.62x speedup compared to the baseline. -## Quantize the model into FP8 +
+
+ +### FP8 Quantization To further speed up the inference and reduce memory usage, we can quantize the model into FP8 with dynamic quantization. We must quantize both the activation and weight of the transformer model to utilize the 8-bit **Tensor Cores** on NVIDIA GPUs. Here, we use `float8_weight_only` and `float8_dynamic_activation_float8_weight`to quantize the text encoder and transformer model respectively. The default quantization method is per tensor quantization. If your GPU supports row-wise quantization, you can also try it for better accuracy. [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) provides a really good tutorial on how to quantize models in `diffusers` and achieve a good speedup. -Here, we simply install the latest `torchao` that is capable of quantizing HunyuanVideo. +Here, we simply install the latest `torchao` that is capable of quantizing FLUX.1-dev and HunyuanVideo. If you are not familiar with `torchao` quantization, you can refer to this [documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md). ```bash @@ -149,7 +173,61 @@ If you are not familiar with `torch.compile`, you can refer to the [official tut In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. -**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so you might need to tweak the `residual_diff_threshold` to a larger value to make it take effect. +**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. + + + + +```python +import time +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe( + pipe, + residual_diff_threshold=0.12, # Use a larger value to make the cache take effect +) + +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only + +quantize_(pipe.text_encoder, float8_weight_only()) +quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune-no-cudagraphs", +) + +# Enable memory savings +# pipe.enable_model_cpu_offload() +# pipe.enable_sequential_cpu_offload() + +for i in range(2): + begin = time.time() + image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, + ).images[0] + end = time.time() + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +print("Saving image to flux.png") +image.save("flux.png") +``` + +We can see that the quantization and compilation process can further speed up the inference. +Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 7.56s, which is a 3.48x speedup compared to the baseline. + + + ```python import time @@ -213,9 +291,12 @@ So here we skip measuring the speedup with quantization and compilation on one s If you want to run HunyuanVideo with `torch.compile` on GPUs with less than 80GB memory, you can try reducing the resolution and the number of frames to avoid OOM errors. Due to the fact that large video generation models usually have performance bottleneck on the attention computation rather than the fully connected layers, we don't observe a significant speedup with quantization and compilation. -However, models like `FLUX` and `SD3` can benefit a lot from quantization and compilation, it is suggested to try it for these models. +However, models like `FLUX.1-dev` can benefit a lot from quantization and compilation, it is suggested to try it for these models. + + + -## Parallelize the inference with Context Parallelism +### Context Parallelism A lot faster than before, right? But we are not satisfied with the speedup we have achieved so far. If we want to accelerate the inference further, we can use context parallelism to parallelize the inference. @@ -225,7 +306,95 @@ We provide very detailed instructions and examples of how to scale up the infere Users can easily launch the inference with multiple GPUs by calling `torchrun`. If there is a need to make the inference process persistent and serviceable, it is suggested to use `torch.multiprocessing` to write your own inference processor, which can eliminate the overhead of launching the process and loading and recompiling the model. -Below is our ultimate code to achieve the fastest HunyuanVideo inference: + + + +Below is our ultimate code to achieve a much faster FLUX.1-dev inference: + +```python +import time +import torch +import torch.distributed as dist +from diffusers import FluxPipeline + +dist.init_process_group() + +torch.cuda.set_device(dist.get_rank()) + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +from para_attn.context_parallel import init_context_parallel_mesh +from para_attn.context_parallel.diffusers_adapters import parallelize_pipe +from para_attn.parallel_vae.diffusers_adapters import parallelize_vae + +mesh = init_context_parallel_mesh( + pipe.device.type, + max_ring_dim_size=2, +) +parallelize_pipe( + pipe, + mesh=mesh, +) +parallelize_vae(pipe.vae, mesh=mesh._flatten()) + +from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe + +apply_cache_on_pipe( + pipe, + residual_diff_threshold=0.12, # Use a larger value to make the cache take effect +) + +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only + +quantize_(pipe.text_encoder, float8_weight_only()) +quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) +torch._inductor.config.reorder_for_compute_comm_overlap = True +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune-no-cudagraphs", +) + +# Enable memory savings +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) +# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) + +for i in range(2): + begin = time.time() + image = pipe( + "A cat holding a sign that says hello world", + num_inference_steps=28, + output_type="pil" if dist.get_rank() == 0 else "pt", + ).images[0] + end = time.time() + if dist.get_rank() == 0: + if i == 0: + print(f"Warm up time: {end - begin:.2f}s") + else: + print(f"Time: {end - begin:.2f}s") + +if dist.get_rank() == 0: + print("Saving image to flux.png") + image.save("flux.png") + +dist.destroy_process_group() +``` + +We save the above code to `run_flux.py` and run it with `torchrun`: + +```bash +# Use --nproc_per_node to specify the number of GPUs +torchrun --nproc_per_node=2 run_flux.py +``` + +With 2 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 8.20 seconds, which is a 3.21x speedup compared to the baseline. +And with 4 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 3.90 seconds, which is a 6.75x speedup compared to the baseline. + + + + +Below is our ultimate code to achieve a much faster HunyuanVideo inference: ```python import time @@ -238,9 +407,6 @@ dist.init_process_group() torch.cuda.set_device(dist.get_rank()) -# [rank1]: RuntimeError: Expected mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.) -torch.backends.cuda.enable_cudnn_sdp(False) - model_id = "tencent/HunyuanVideo" transformer = HunyuanVideoTransformer3DModel.from_pretrained( model_id, @@ -314,12 +480,31 @@ dist.destroy_process_group() We save the above code to `run_hunyuan_video.py` and run it with `torchrun`: ```bash +# Use --nproc_per_node to specify the number of GPUs torchrun --nproc_per_node=8 run_hunyuan_video.py ``` With 8 NVIDIA L20 GPUs, we can generate 129 frames with 720p resolution in 30 inference steps in 649.23 seconds. This is a 5.66x speedup compared to the baseline! -## Conclusion + + + +### Conclusion of Inference Optimization with ParaAttention + + + + +| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup | +| - | - | - | - | - | +| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x | +| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x | +| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x | +| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x | +| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x | +| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x | + + + | GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup | | - | - | - | - | - | @@ -328,3 +513,6 @@ With 8 NVIDIA L20 GPUs, we can generate 129 frames with 720p resolution in 30 in | NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x | | NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x | | NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x | + + + diff --git a/docs/source/en/optimization/para_attn_flux.md b/docs/source/en/optimization/para_attn_flux.md deleted file mode 100644 index 402b543885db..000000000000 --- a/docs/source/en/optimization/para_attn_flux.md +++ /dev/null @@ -1,285 +0,0 @@ -# Fastest FLUX Inference with ParaAttention - -[![](https://mermaid.ink/img/pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw?type=png)](https://mermaid.live/edit#pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw) - -## Introduction - -During the past year, we have seen the rapid development of image generation models with the release of several open-source models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [SD3.5-large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large). -It is very exciting to see that open source image generation models are going to beat closed source. -However, the inference speed of these models is still a bottleneck for real-time applications and deployment. - -In this article, we will use [ParaAttention](https://github.com/chengzeyi/ParaAttention), a library implements **Context Parallelism** and **First Block Cache**, as well as other techniques like `torch.compile` and **FP8 Dynamic Quantization**, to achieve the fastest inference speed for FLUX.1-dev. - -**We set up our experiments on NVIDIA L20 GPUs, which only have PCIe support.** -**If you have NVIDIA A100 or H100 GPUs with NVLink support, you can achieve a better speedup with context parallelism, especially when the number of GPUs is large.** - -## FLUX.1-dev Inference with `diffusers` - -Like many other generative AI models, FLUX.1-dev has its official code repository and is supported by other frameworks like `diffusers` and `ComfyUI`. -In this article, we will focus on optimizing the inference speed of FLUX.1-dev with `diffusers`. -To use FLUX.1-dev with `diffusers`, we need to install its latest version: - -```bash -pip3 install -U diffusers -``` - -Then, we can load the model and generate images with the following code: - -```python -import time -import torch -from diffusers import FluxPipeline - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.bfloat16, -).to("cuda") - -# Enable memory savings -# pipe.enable_model_cpu_offload() -# pipe.enable_sequential_cpu_offload() - -begin = time.time() -image = pipe( - "A cat holding a sign that says hello world", - num_inference_steps=28, -).images[0] -end = time.time() -print(f"Time: {end - begin:.2f}s") - -print("Saving image to flux.png") -image.save("flux.png") -``` - -This is our baseline. -On one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 26.36 seconds. - -## Apply First Block Cache on FLUX.1-dev - -By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. -However, it is hard to decide when to reuse the cache to ensure the quality of the generated image. -Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. -And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple image and video DiT baselines. -However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. -In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. -When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. -This has been proved to be effective in our experiments and we can achieve an up to 1.5x speedup on FLUX.1-dev inference with very good quality. - -
- Cache in Diffusion Transformer -
How AdaCache works, First Block Cache is a variant of it
-
- -To apply the first block cache on FLUX.1-dev, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.08`, which is the default value for FLUX models. - -```python -import time -import torch -from diffusers import FluxPipeline - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.bfloat16, -).to("cuda") - -from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe - -apply_cache_on_pipe(pipe, residual_diff_threshold=0.08) - -# Enable memory savings -# pipe.enable_model_cpu_offload() -# pipe.enable_sequential_cpu_offload() - -begin = time.time() -image = pipe( - "A cat holding a sign that says hello world", - num_inference_steps=28, -).images[0] -end = time.time() -print(f"Time: {end - begin:.2f}s") - -print("Saving image to flux.png") -image.save("flux.png") -``` - -| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 | -| - | - | - | - | - | - | -| Preview | ![Original](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_original.png) | ![FBCache rdt=0.06](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.06.png) | ![FBCache rdt=0.08](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.08.png) | ![FBCache rdt=0.10](https://github.com/chengzeyi/ParaAttention/blob/main/assets/flux_fbc_0.10.png) | ![FBCache rdt=0.12](https://github.com/chengzeyi/ParaAttention/raw/main/assets/flux_fbc_0.12.png) | -| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 | - -We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated image. -Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 17.01 seconds. This is a 1.55x speedup compared to the baseline. - -## Quantize the model into FP8 - -To further speed up the inference and reduce memory usage, we can quantize the model into FP8 with dynamic quantization. -We must quantize both the activation and weight of the transformer model to utilize the 8-bit **Tensor Cores** on NVIDIA GPUs. -Here, we use `float8_weight_only` and `float8_dynamic_activation_float8_weight`to quantize the text encoder and transformer model respectively. -The default quantization method is per tensor quantization. If your GPU supports row-wise quantization, you can also try it for better accuracy. -[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) provides a really good tutorial on how to quantize models in `diffusers` and achieve a good speedup. -Here, we simply install the latest `torchao` that is capable of quantizing FLUX.1-dev. -If you are not familiar with `torchao` quantization, you can refer to this [documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md). - -```bash -pip3 install -U torch torchao -``` - -We also need to pass the model to `torch.compile` to gain actual speedup. -`torch.compile` with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` can help us to achieve the best performance by generating and selecting the best kernel for the model inference. -The compilation process could take a long time, but it is worth it. -If you are not familiar with `torch.compile`, you can refer to the [official tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). -In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. -We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. - -**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. - -```python -import time -import torch -from diffusers import FluxPipeline - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.bfloat16, -).to("cuda") - -from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe - -apply_cache_on_pipe( - pipe, - residual_diff_threshold=0.12, # Use a larger value to make the cache take effect -) - -from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only - -quantize_(pipe.text_encoder, float8_weight_only()) -quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) -pipe.transformer = torch.compile( - pipe.transformer, mode="max-autotune-no-cudagraphs", -) - -# Enable memory savings -# pipe.enable_model_cpu_offload() -# pipe.enable_sequential_cpu_offload() - -for i in range(2): - begin = time.time() - image = pipe( - "A cat holding a sign that says hello world", - num_inference_steps=28, - ).images[0] - end = time.time() - if i == 0: - print(f"Warm up time: {end - begin:.2f}s") - else: - print(f"Time: {end - begin:.2f}s") - -print("Saving image to flux.png") -image.save("flux.png") -``` - -We can see that the quantization and compilation process can further speed up the inference. -Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 7.56s, which is a 3.48x speedup compared to the baseline. - -## Parallelize the inference with Context Parallelism - -A lot faster than before, right? But we are not satisfied with the speedup we have achieved so far. -If we want to accelerate the inference further, we can use context parallelism to parallelize the inference. -Libraries like [xDit](https://github.com/xdit-project/xDiT) and our [ParaAttention](https://github.com/chengzeyi/ParaAttention) provide ways to scale up the inference with multiple GPUs. -In ParaAttention, we design our API in a compositional way so that we can combine context parallelism with first block cache and dynamic quantization all together. -We provide very detailed instructions and examples of how to scale up the inference with multiple GPUs in our ParaAttention repository. -Users can easily launch the inference with multiple GPUs by calling `torchrun`. -If there is a need to make the inference process persistent and serviceable, it is suggested to use `torch.multiprocessing` to write your own inference processor, which can eliminate the overhead of launching the process and loading and recompiling the model. - -Below is our ultimate code to achieve the fastest FLUX.1-dev inference: - -```python -import time -import torch -import torch.distributed as dist -from diffusers import FluxPipeline - -dist.init_process_group() - -torch.cuda.set_device(dist.get_rank()) - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.bfloat16, -).to("cuda") - -from para_attn.context_parallel import init_context_parallel_mesh -from para_attn.context_parallel.diffusers_adapters import parallelize_pipe -from para_attn.parallel_vae.diffusers_adapters import parallelize_vae - -mesh = init_context_parallel_mesh( - pipe.device.type, - max_ring_dim_size=2, -) -parallelize_pipe( - pipe, - mesh=mesh, -) -parallelize_vae(pipe.vae, mesh=mesh._flatten()) - -from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe - -apply_cache_on_pipe( - pipe, - residual_diff_threshold=0.12, # Use a larger value to make the cache take effect -) - -from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only - -quantize_(pipe.text_encoder, float8_weight_only()) -quantize_(pipe.transformer, float8_dynamic_activation_float8_weight()) -torch._inductor.config.reorder_for_compute_comm_overlap = True -pipe.transformer = torch.compile( - pipe.transformer, mode="max-autotune-no-cudagraphs", -) - -# Enable memory savings -# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) -# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) - -for i in range(2): - begin = time.time() - image = pipe( - "A cat holding a sign that says hello world", - num_inference_steps=28, - output_type="pil" if dist.get_rank() == 0 else "pt", - ).images[0] - end = time.time() - if dist.get_rank() == 0: - if i == 0: - print(f"Warm up time: {end - begin:.2f}s") - else: - print(f"Time: {end - begin:.2f}s") - -if dist.get_rank() == 0: - print("Saving image to flux.png") - image.save("flux.png") - -dist.destroy_process_group() -``` - -We save the above code to `run_flux.py` and run it with `torchrun`: - -```bash -# Use --nproc_per_node to specify the number of GPUs -torchrun --nproc_per_node=2 run_flux.py -``` - -With 2 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 8.20 seconds, which is a 3.21x speedup compared to the baseline. -And with 4 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 3.90 seconds, which is a 6.75x speedup compared to the baseline. - -## Conclusion - -| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup | -| - | - | - | - | - | -| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x | -| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x | -| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x | -| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x | -| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x | -| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x | diff --git a/docs/source/en/optimization/xdit.md b/docs/source/en/optimization/xdit.md index 33ff8dc255d0..6c7d25fcd3b3 100644 --- a/docs/source/en/optimization/xdit.md +++ b/docs/source/en/optimization/xdit.md @@ -118,4 +118,4 @@ More detailed performance metric can be found on our [github page](https://githu [USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) -[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430) \ No newline at end of file +[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430) From d90851a08e90816651e02d9c5585c3ef83d7f6d7 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Tue, 14 Jan 2025 16:50:27 +0800 Subject: [PATCH 05/30] fix --- docs/source/en/optimization/xdit.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/xdit.md b/docs/source/en/optimization/xdit.md index 6c7d25fcd3b3..33ff8dc255d0 100644 --- a/docs/source/en/optimization/xdit.md +++ b/docs/source/en/optimization/xdit.md @@ -118,4 +118,4 @@ More detailed performance metric can be found on our [github page](https://githu [USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) -[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430) +[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430) \ No newline at end of file From 83e651fa772233ae71a2f22fb2e8bada57914841 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Tue, 14 Jan 2025 16:57:59 +0800 Subject: [PATCH 06/30] fix --- docs/source/en/optimization/para_attn.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index f9a46aadb288..42a8af04799b 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -173,7 +173,8 @@ If you are not familiar with `torch.compile`, you can refer to the [official tut In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. -**Note**: we find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. +> [!TIP] +> We find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. From 646920cf24e1ea813784c1ceed15a083490b5290 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:26:59 +0800 Subject: [PATCH 07/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 42a8af04799b..3b9294b0669a 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -7,7 +7,6 @@ -## Introduction Large image and video generation models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), can be an inference challenge for real-time applications and deployment because of their size. From 1b7ea1ac8898f5883c2ae0b83630f837c106d804 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:28:11 +0800 Subject: [PATCH 08/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 3b9294b0669a..d227604cc5f4 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -12,19 +12,13 @@ Large image and video generation models, such as [FLUX.1-dev](https://huggingfac [ParaAttention](https://github.com/chengzeyi/ParaAttention) is a library that implements **context parallelism** and **first block cache**, and can be combined with other techniques (torch.compile, fp8 dynamic quantization), to accelerate inference. -This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs with only PCIe support. +This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs in fp16/bf16 precision with only PCIe support. No optimizations are applied, except for HunyuanVideo to avoid out-of-memory errors. + +FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds. HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds. > [!TIP] > For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs. -## Baseline Performance - -Our baseline performance is measured on a single NVIDIA L20 GPU, with normal FP16/BF16 precision, For FLUX.1-dev, we don't apply any optimization, and for HunyuanVideo, we only apply the patch `apply_cache_on_pipe(pipe, residual_diff_threshold=0.0)` to make it run without a square `attn_mask` to avoid OOM errors. - -For FLUX.1-dev, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 26.36 seconds. - -For HunyuanVideo, we can generate 129 frames with 720p resolution in 30 inference steps in 3675.71 seconds. - ### First Block Cache By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. From a2345d6a2966578a3c4e49baa971280c70b2da2f Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:28:25 +0800 Subject: [PATCH 09/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index d227604cc5f4..0c231a849cfb 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -21,15 +21,11 @@ FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 ### First Block Cache -By caching the output of the transformer blocks in the transformer model and resuing them in the next inference steps, we can reduce the computation cost and make the inference faster. -However, it is hard to decide when to reuse the cache to ensure the quality of the generated image or video. -Recently, [TeaCache](https://github.com/ali-vilab/TeaCache) suggests that we can use the timestep embedding to approximate the difference among model outputs. -And [AdaCache](https://adacache-dit.github.io) also shows that caching can contribute grant significant inference speedups without sacrificing the generation quality, across multiple image and video DiT baselines. -However, TeaCache is still a bit complex as it needs a rescaling strategy to ensure the accuracy of the cache. -In ParaAttention, we find that we can directly use **the residual difference of the first transformer block output** to approximate the difference among model outputs. -When the difference is small enough, we can reuse the residual difference of previous inference steps, meaning that we in fact skip this denoising step. - -This has been proved to be effective in our experiments and we can achieve an 2.0x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality. +Caching the output of the transformers blocks in the model and reusing them in the next inference steps reduces the computation cost and makes inference faster. + +However, it is hard to decide when to reuse the cache to ensure quality generated images or videos. ParaAttention directly uses the **residual difference of the first transformer block output** to approximate the difference among model outputs. When the difference is small enough, the residual difference of previous inference steps is reused. In other words, the denoising step is skipped. + +This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
Cache in Diffusion Transformer From 8b8379b6b9dfbb2f801ce7c5b9de49a8e23d0450 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:28:47 +0800 Subject: [PATCH 10/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 0c231a849cfb..12a3f9728b93 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -19,7 +19,7 @@ FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 > [!TIP] > For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs. -### First Block Cache +## First Block Cache Caching the output of the transformers blocks in the model and reusing them in the next inference steps reduces the computation cost and makes inference faster. From 1722d07bf115534a184dfa18b52afd4f7372a4c4 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:28:57 +0800 Subject: [PATCH 11/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 12a3f9728b93..8000914445c7 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -378,8 +378,7 @@ We save the above code to `run_flux.py` and run it with `torchrun`: torchrun --nproc_per_node=2 run_flux.py ``` -With 2 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 8.20 seconds, which is a 3.21x speedup compared to the baseline. -And with 4 NVIDIA L20 GPUs, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 3.90 seconds, which is a 6.75x speedup compared to the baseline. +Inference speed is reduced to 8.20 seconds compared to the baseline, or 3.21x faster, with 2 NVIDIA L20 GPUs. On 4 L20s, inference speed is 3.90 seconds, or 6.75x faster. From 51feb2e72bc961eb292fed727f01b54bb9bd8723 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:29:14 +0800 Subject: [PATCH 12/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 8000914445c7..26209ad0c1bc 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -383,7 +383,7 @@ Inference speed is reduced to 8.20 seconds compared to the baseline, or 3.21x fa -Below is our ultimate code to achieve a much faster HunyuanVideo inference: +The code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed. ```python import time From 327dcedbe1895ee318e1917b3c2402a53634dd5a Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:29:22 +0800 Subject: [PATCH 13/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 26209ad0c1bc..c9d59d16fa93 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -466,7 +466,7 @@ if dist.get_rank() == 0: dist.destroy_process_group() ``` -We save the above code to `run_hunyuan_video.py` and run it with `torchrun`: +Save to `run_hunyuan_video.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html). ```bash # Use --nproc_per_node to specify the number of GPUs From 66e13a80b823589a00efa0dcb94497558f811f11 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:29:30 +0800 Subject: [PATCH 14/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index c9d59d16fa93..6773238518b2 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -473,7 +473,7 @@ Save to `run_hunyuan_video.py` and launch it with [torchrun](https://pytorch.org torchrun --nproc_per_node=8 run_hunyuan_video.py ``` -With 8 NVIDIA L20 GPUs, we can generate 129 frames with 720p resolution in 30 inference steps in 649.23 seconds. This is a 5.66x speedup compared to the baseline! +Inference speed is reduced to 649.23 seconds compared to the baseline, or 5.66x faster, with 8 NVIDIA L20 GPUs. From 91d2ef53eae038f3cc3d7f626c319006b4c7999d Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:29:38 +0800 Subject: [PATCH 15/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 6773238518b2..3700f4063965 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -478,7 +478,7 @@ Inference speed is reduced to 649.23 seconds compared to the baseline, or 5.66x -### Conclusion of Inference Optimization with ParaAttention +## Benchmarks From 76ae1cbe769723ab55a869bf6f66248beb3d70f9 Mon Sep 17 00:00:00 2001 From: C Date: Wed, 15 Jan 2025 14:29:55 +0800 Subject: [PATCH 16/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 3700f4063965..e5c2c5ea19ce 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -35,7 +35,7 @@ This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very go -To apply first block cache on FLUX.1-dev, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.08`, which is the default value for FLUX models. +To apply first block cache on FLUX.1-dev, call `apply_cache_on_pipe` as shown below. 0.08 is the default residual difference value for FLUX models. ```python import time From 13474b52df3fea89483210ed1e9c4f9960870c03 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Wed, 15 Jan 2025 14:35:12 +0800 Subject: [PATCH 17/30] fix --- docs/source/en/optimization/para_attn.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index e5c2c5ea19ce..8d544dec9e6b 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -12,9 +12,10 @@ Large image and video generation models, such as [FLUX.1-dev](https://huggingfac [ParaAttention](https://github.com/chengzeyi/ParaAttention) is a library that implements **context parallelism** and **first block cache**, and can be combined with other techniques (torch.compile, fp8 dynamic quantization), to accelerate inference. -This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs in fp16/bf16 precision with only PCIe support. No optimizations are applied, except for HunyuanVideo to avoid out-of-memory errors. +This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs. +No optimizations are applied for our baseline benchmark, except for HunyuanVideo to avoid out-of-memory errors. -FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds. HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds. +Our baseline benchmark shows that FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds, and HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds. > [!TIP] > For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs. @@ -299,7 +300,7 @@ If there is a need to make the inference process persistent and serviceable, it -Below is our ultimate code to achieve a much faster FLUX.1-dev inference: +The code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed. ```python import time @@ -383,7 +384,7 @@ Inference speed is reduced to 8.20 seconds compared to the baseline, or 3.21x fa -The code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed. +The code sample below combines First Block Cache and Context Parallelism for the fastest inference speed. ```python import time From b098c3758e385eaee6b7f2c2f8d19ae8ebb340c5 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 16 Jan 2025 10:56:59 +0800 Subject: [PATCH 18/30] update links --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 8d544dec9e6b..d07d1582dbc1 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -70,7 +70,7 @@ image.save("flux.png") | Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 | | - | - | - | - | - | - | -| Preview | ![Original](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache rdt=0.06](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/chengzeyi/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) | +| Preview | ![Original](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache rdt=0.06](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) | | Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 | We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated image. From cdfd39c7ca350a46164a467c63622f4b658a25d4 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 10:58:25 +0800 Subject: [PATCH 19/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index d07d1582dbc1..12b4e2504de4 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -73,8 +73,7 @@ image.save("flux.png") | Preview | ![Original](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache rdt=0.06](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) | | Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 | -We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated image. -Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 17.01 seconds. This is a 1.55x speedup compared to the baseline. +First Block Cache reduced the inference speed to 17.01 seconds compared to the baseline, or 1.55x faster, while maintaining nearly zero quality loss. From d6fa4ea8375b49fd7d027b4c01c1031f877756b3 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 10:58:48 +0800 Subject: [PATCH 20/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 12b4e2504de4..86fe438cf091 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -371,7 +371,7 @@ if dist.get_rank() == 0: dist.destroy_process_group() ``` -We save the above code to `run_flux.py` and run it with `torchrun`: +Save to `run_flux.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html). ```bash # Use --nproc_per_node to specify the number of GPUs From 98626bf328569816349ece99eac2afe541ffdb3b Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:00:33 +0800 Subject: [PATCH 21/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 86fe438cf091..f350108f4012 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -78,7 +78,7 @@ First Block Cache reduced the inference speed to 17.01 seconds compared to the b -To apply first block cache on HunyuanVideo, we can call `apply_cache_on_pipe` with `residual_diff_threshold=0.06`, which is the default value for HunyuanVideo. +To apply First Block Cache on HunyuanVideo, `apply_cache_on_pipe` as shown below. 0.06 is the default residual difference value for HunyuanVideo models. ```python import time From 2d6a2e142c524f36f45464274c37a29a0da6fc14 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 16 Jan 2025 11:02:41 +0800 Subject: [PATCH 22/30] fix --- docs/source/en/optimization/para_attn.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index f350108f4012..f4e5b480b123 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -121,20 +121,20 @@ print("Saving video to hunyuan_video.mp4") export_to_video(output, "hunyuan_video.mp4", fps=15) ``` -#### HunyuanVideo without FBCache - -#### HunyuanVideo with FBCache + HunyuanVideo without FBCache + HunyuanVideo with FBCache + We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated video. Now, on one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 2271.06 seconds. This is a 1.62x speedup compared to the baseline. From 03abeda8e65514793041d942c6bd36cebcc55aa4 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:03:16 +0800 Subject: [PATCH 23/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index f4e5b480b123..da8dbcb40213 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -135,8 +135,7 @@ export_to_video(output, "hunyuan_video.mp4", fps=15) HunyuanVideo with FBCache -We observe that the first block cache is very effective in speeding up the inference, and maintaining nearly no quality loss in the generated video. -Now, on one single NVIDIA L20 GPU, we can generate 129 frames with 720p resolution in 30 inference steps in 2271.06 seconds. This is a 1.62x speedup compared to the baseline. +First Block Cache reduced the inference speed to 2271.06 seconds compared to the baseline, or 1.62x faster, while maintaining nearly zero quality loss. From 3c04cb8df4ad5e15123b88d685c840086167e9a7 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:03:56 +0800 Subject: [PATCH 24/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index da8dbcb40213..1b71eec95d3b 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -285,15 +285,14 @@ However, models like `FLUX.1-dev` can benefit a lot from quantization and compil -### Context Parallelism - -A lot faster than before, right? But we are not satisfied with the speedup we have achieved so far. -If we want to accelerate the inference further, we can use context parallelism to parallelize the inference. -Libraries like [xDit](https://github.com/xdit-project/xDiT) and our [ParaAttention](https://github.com/chengzeyi/ParaAttention) provide ways to scale up the inference with multiple GPUs. -In ParaAttention, we design our API in a compositional way so that we can combine context parallelism with first block cache and dynamic quantization all together. -We provide very detailed instructions and examples of how to scale up the inference with multiple GPUs in our ParaAttention repository. -Users can easily launch the inference with multiple GPUs by calling `torchrun`. -If there is a need to make the inference process persistent and serviceable, it is suggested to use `torch.multiprocessing` to write your own inference processor, which can eliminate the overhead of launching the process and loading and recompiling the model. +## Context Parallelism + +Context Parallelism parallelizes inference and scales with multiple GPUs. The ParaAttention compositional design allows you to combine Context Parallelism with First Block Cache and dynamic quantization. + +> [!TIP] +> Refer to the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main) repository for detailed instructions and examples of how to scale inference with multiple GPUs. + +If the inference process needs to be persistent and serviceable, it is suggested to use [torch.multiprocessing](https://pytorch.org/docs/stable/multiprocessing.html) to write your own inference processor. This can eliminate the overhead of launching the process and loading and recompiling the model. From edc062439664e6008337fabaf61eef2029b01bd9 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:04:34 +0800 Subject: [PATCH 25/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 1b71eec95d3b..4eab3f8637be 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -274,13 +274,9 @@ print("Saving video to hunyuan_video.mp4") export_to_video(output, "hunyuan_video.mp4", fps=15) ``` -The NVIDIA L20 GPU only has 48GB memory and could face OOM errors after compiling the model and not calling `enable_model_cpu_offload`, -because the HunyuanVideo has very large activation tensors when running with high resolution and large number of frames. -So here we skip measuring the speedup with quantization and compilation on one single NVIDIA L20 GPU and choose to use context parallelism to release the memory pressure. -If you want to run HunyuanVideo with `torch.compile` on GPUs with less than 80GB memory, you can try reducing the resolution and the number of frames to avoid OOM errors. +A NVIDIA L20 GPU only has 48GB memory and could face out-of-memory (OOM) errors after compilation and if `enable_model_cpu_offload` isn't called because HunyuanVideo has very large activation tensors when running with high resolution and large number of frames. For GPUs with less than 80GB of memory, you can try reducing the resolution and number of frames to avoid OOM errors. -Due to the fact that large video generation models usually have performance bottleneck on the attention computation rather than the fully connected layers, we don't observe a significant speedup with quantization and compilation. -However, models like `FLUX.1-dev` can benefit a lot from quantization and compilation, it is suggested to try it for these models. +Large video generation models are usually bottlenecked by the attention computations rather than the fully connected layers. These models don't significantly benefit from quantization and torch.compile. From 2c1edf3c1761cd93e5e5a6c5dd0c809d883e0220 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:04:55 +0800 Subject: [PATCH 26/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 4eab3f8637be..2cefbd1557f3 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -212,8 +212,7 @@ print("Saving image to flux.png") image.save("flux.png") ``` -We can see that the quantization and compilation process can further speed up the inference. -Now, on one single NVIDIA L20 GPU, we can generate 1 image with 1024x1024 resolution in 28 inference steps in 7.56s, which is a 3.48x speedup compared to the baseline. +fp8 dynamic quantization and torch.compile reduced the inference speed to 7.56 seconds compared to the baseline, or 3.48x faster. From a525f05923ee68c3abd9260fa8d8bd3d664b31d1 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:05:16 +0800 Subject: [PATCH 27/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 2cefbd1557f3..11e5c9334c8e 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -162,7 +162,7 @@ In this example, we only quantize the transformer model, but you can also quanti We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. > [!TIP] -> We find that dynamic quantization can significantly change the distribution of the model output, so we need to change the `residual_diff_threshold` to a larger value to make it take effect. +> Dynamic quantization can significantly change the distribution of the model output, so you need to change the `residual_diff_threshold` to a larger value for it to take effect. From 6d30ba108df80e6c1802235536cc2b57c72ccd52 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:05:41 +0800 Subject: [PATCH 28/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 11e5c9334c8e..6d1bb63cb212 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -154,12 +154,9 @@ If you are not familiar with `torchao` quantization, you can refer to this [docu pip3 install -U torch torchao ``` -We also need to pass the model to `torch.compile` to gain actual speedup. -`torch.compile` with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` can help us to achieve the best performance by generating and selecting the best kernel for the model inference. -The compilation process could take a long time, but it is worth it. -If you are not familiar with `torch.compile`, you can refer to the [official tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). -In this example, we only quantize the transformer model, but you can also quantize the text encoder to reduce more memory usage. -We also need to notice that the actually compilation process is done on the first time the model is called, so we need to warm up the model to measure the speedup correctly. +[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` selects the best kernel for performance. Compilation can take a long time if it's the first time the model is called, but it is worth it once the model has been compiled. + +This example only quantizes the transformer model, but you can also quantize the text encoder to reduce memory usage even more. > [!TIP] > Dynamic quantization can significantly change the distribution of the model output, so you need to change the `residual_diff_threshold` to a larger value for it to take effect. From 873426d0fab0d58a2dc6b380ec8bfc28057e9970 Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:06:17 +0800 Subject: [PATCH 29/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 6d1bb63cb212..740922f3b9be 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -142,13 +142,13 @@ First Block Cache reduced the inference speed to 2271.06 seconds compared to the ### FP8 Quantization -To further speed up the inference and reduce memory usage, we can quantize the model into FP8 with dynamic quantization. -We must quantize both the activation and weight of the transformer model to utilize the 8-bit **Tensor Cores** on NVIDIA GPUs. -Here, we use `float8_weight_only` and `float8_dynamic_activation_float8_weight`to quantize the text encoder and transformer model respectively. -The default quantization method is per tensor quantization. If your GPU supports row-wise quantization, you can also try it for better accuracy. -[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) provides a really good tutorial on how to quantize models in `diffusers` and achieve a good speedup. -Here, we simply install the latest `torchao` that is capable of quantizing FLUX.1-dev and HunyuanVideo. -If you are not familiar with `torchao` quantization, you can refer to this [documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md). +fp8 with dynamic quantization further speeds up inference and reduces memory usage. Both the activations and weights must be quantized in order to use the 8-bit [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/). + +Use `float8_weight_only` and `float8_dynamic_activation_float8_weight` to quantize the text encoder and transformer model. + +The default quantization method is per tensor quantization, but if your GPU supports row-wise quantization, you can also try it for better accuracy. + +Install [torchao](https://github.com/pytorch/ao/tree/main) with the command below. ```bash pip3 install -U torch torchao From 8e9ba97c22138f6ed1c0bf6e8378aa25cfdb173b Mon Sep 17 00:00:00 2001 From: C Date: Thu, 16 Jan 2025 11:06:38 +0800 Subject: [PATCH 30/30] Update docs/source/en/optimization/para_attn.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/para_attn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md index 740922f3b9be..b1b111045590 100644 --- a/docs/source/en/optimization/para_attn.md +++ b/docs/source/en/optimization/para_attn.md @@ -140,7 +140,7 @@ First Block Cache reduced the inference speed to 2271.06 seconds compared to the -### FP8 Quantization +## fp8 quantization fp8 with dynamic quantization further speeds up inference and reduces memory usage. Both the activations and weights must be quantized in order to use the 8-bit [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/).