diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md index 7e15ccb38727..4c34879d33b6 100644 --- a/_posts/2024-08-07-flexattention.md +++ b/_posts/2024-08-07-flexattention.md @@ -1,6 +1,7 @@ --- layout: blog_detail title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention" +author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong" --- ![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"} @@ -120,6 +121,7 @@ Note that unlike typical implementations, this does *not* need to materialize a ### ALiBi Bias ![alibi bias](/assets/images/flexattention/fg6.png){:style="max-width:600px; display:block; margin-left: auto; margin-right: auto; width:100%"} +

Source: Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

ALiBi was introduced in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409), and claims to have beneficial properties for length extrapolation at inference. Notably, MosaicML has pointed to [“lack of kernel support”](https://twitter.com/jefrankle/status/1804567458092605736) as the main reason why they eventually switched from ALiBi to rotary embeddings. @@ -137,17 +139,18 @@ This demonstrates one interesting piece of flexibility `torch.compile` provides ### Soft-capping -Soft-capping is a technique used in [Gemma2](https://huggingface.co/blog/gemma2\#soft-capping-and-attention-implementations) and Grok-1 that prevents logits from growing excessively large. In FlexAttention, it looks like: +Soft-capping is a technique used in [Gemma2](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) and Grok-1 that prevents logits from growing excessively large. In FlexAttention, it looks like: ```py softcap = 20 -def soft_cap(score, b, h, q_idx, kv_idx): score = score / softcap +def soft_cap(score, b, h, q_idx, kv_idx): + score = score / softcap score = torch.tanh(score) score = score * softcap return score ``` -Note that we also automatically generate the backwards pass from the forwards pass here. Also, although this implementation is semantically correct, we likely want to use a tanh approximation in this case for performance reasons. See [attention-gym](https://github.com/pytorch-labs/attention-gym/blob/738268eae279c48dc8c4d1c6f40b3cfaec648831/attn\_gym/mods/softcapping.py\#L1) for more details. +Note that we also automatically generate the backwards pass from the forwards pass here. Also, although this implementation is semantically correct, we likely want to use a tanh approximation in this case for performance reasons. See [attention-gym](https://github.com/pytorch-labs/attention-gym/blob/main/attn_gym/mods/softcapping.py) for more details. ### Causal Mask @@ -164,7 +167,7 @@ However, masking is special compared to other modifications \- if something is m ## Mask Mods -To take advantage of sparsity from masking, we need to do some more work. Specifically, by passing a `mask_mod` to [`create_block_mask`](https://github.com/pytorch/pytorch/blob/e49c0acc396e89baf8c6450e1fa0571d4ce2d4ed/torch/nn/attention/flex_attention.py\#L594), we can create a `BlockMask`. FlexAttention can then use `BlockMask` to take advantage of the sparsity\! +To take advantage of sparsity from masking, we need to do some more work. Specifically, by passing a `mask_mod` to [`create_block_mask`](https://github.com/pytorch/pytorch/blob/e49c0acc396e89baf8c6450e1fa0571d4ce2d4ed/torch/nn/attention/flex_attention.py#L594), we can create a `BlockMask`. FlexAttention can then use `BlockMask` to take advantage of the sparsity\! The signature of `mask_mod` is very similar to `score_mod` \- just without the `score`. In particular @@ -201,6 +204,7 @@ While the TFlops are roughly the same, the execution time is 2x faster for the m ### Sliding Window \+ Causal ![Sliding Window Causal diagrams](/assets/images/flexattention/fg8.png){:style="width:100%"} +

Source: Mistral 7B

Popularized by [Mistral](https://arxiv.org/abs/2310.06825), sliding window attention (also known as local attention) takes advantage of the intuition that the most recent tokens are the most useful. In particular, it allows the query token to only attend to, say, the 1024 most recent tokens. This is often used together with causal attention. @@ -229,6 +233,7 @@ We benchmark it against `F.scaled_dot_product_attention` with a sliding window m ### PrefixLM ![PrefixLM diagram](/assets/images/flexattention/fg10.png){:style="max-width:600px; display:block; margin-left: auto; margin-right: auto; width:100%"} +

Source: PaliGemma: A versatile 3B VLM for transfer

The T5 architecture, proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), describes an attention variant that performs full bidirectional attention on a “prefix”, and causal attention on the rest. We again compose two mask functions to accomplish this, one for causal masking and one that is based off of the prefix length. @@ -262,7 +267,7 @@ Through `BlockMask`, we can support this efficiently in FlexAttention as well\! document_id: [SEQ_LEN] def document_masking(b, h, q_idx, kv_idx): - return document_id[q_idx] == document_id[kv_idx] + return document_id[q_idx] == document_id[kv_idx] ``` And that’s it\! In this case, we see that we end up with a blockdiagonal mask. @@ -424,7 +429,7 @@ Although the results are not bitwise identical, we are confident that FlexAttent ### Performance -Generally speaking, FlexAttention is nearly as performant as a handwritten Triton kernel, which is unsurprising, as we heavily leverage a handwritten Triton kernel. However, due to its generality, we do incur a small performance penalty. For example, we must incur some additional latency to determine which block to compute next. In some cases, we provide some kernel options that can affect the performance of the kernel while changing its behavior. They can be found here: [performance knobs](https://github.com/pytorch/pytorch/blob/ee09d066d35d7e17cf7e9479c0b8bfc70cffc264/torch/_inductor/kernel/flex_attention.py\#L146-L155) +Generally speaking, FlexAttention is nearly as performant as a handwritten Triton kernel, which is unsurprising, as we heavily leverage a handwritten Triton kernel. However, due to its generality, we do incur a small performance penalty. For example, we must incur some additional latency to determine which block to compute next. In some cases, we provide some kernel options that can affect the performance of the kernel while changing its behavior. They can be found here: [performance knobs](https://github.com/pytorch/pytorch/blob/ee09d066d35d7e17cf7e9479c0b8bfc70cffc264/torch/_inductor/kernel/flex_attention.py#L146-L155) As a case study, let's explore how the knobs affect the performance of causal attention. We will compare performance of the triton kernel versus FlashAttentionv2 on A100. The script can be found [here](https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/score_mod.py). @@ -458,7 +463,20 @@ We look forward to leveraging the approach we used here to more applications in ### Limitations and Future Work +- FlexAttention is currently available in PyTorch nightly releases, we plan to release it as a prototype feature in 2.5.0 - We did not cover how to use FlexAttention for inference here (or how to implement PagedAttention) \- we will cover those in a later post. - We are working to improve the performance of FlexAttention to match FlashAttention3 on H100 GPUs. - FlexAttention requires that all sequence lengths be a multiple of 128 \- this will be addressed soon. - We plan on adding GQA support soon \- for now, you can just replicate the kv heads. + + +### Acknowledgements + +We want to highlight some prior work (and people) that have inspired FlexAttention. + +- Tri Dao's work on FlashAttention +- Francisco Massa and the Xformers team for BlockSparseAttention in Triton +- The Jax team's work on SplashAttention +- Philippe Tillet and Keren Zhou for helping us with Triton +- Ali Hassani for discussions on neighborhood attention +- Everybody who's complained about attention kernels not supporting their favorite attention variant :) \ No newline at end of file diff --git a/assets/images/flexattention/fg9.png b/assets/images/flexattention/fg9.png index b544dc9f86ea..3604c5beea91 100644 Binary files a/assets/images/flexattention/fg9.png and b/assets/images/flexattention/fg9.png differ