_posts/2024-08-07-flexattention.md

…This demonstrates one interesting piece of flexibility `torch.compile` provides…

### Soft-capping
Soft-capping is a technique used in [Gemma2](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) and Grok-1 that prevents logits from growing excessively large. In FlexAttention, it looks like:

```py
softcap = 20
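# A plausible completion of this snippet (an illustrative sketch; the post's
# exact code may differ): divide the score by the cap, squash it with tanh,
# then rescale, so logits smoothly saturate at +/- softcap.
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)  # assumes `torch` is imported, as elsewhere in the post
    return score * softcap
```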

…However, masking is special compared to other modifications \- if something is masked…

## Mask Mods
To take advantage of sparsity from masking, we need to do some more work. Specifically, by passing a `mask_mod` to [`create_block_mask`](https://github.com/pytorch/pytorch/blob/e49c0acc396e89baf8c6450e1fa0571d4ce2d4ed/torch/nn/attention/flex_attention.py#L594), we can create a `BlockMask`. FlexAttention can then use `BlockMask` to take advantage of the sparsity\!

The signature of `mask_mod` is very similar to `score_mod` \- just without the `score`. In particular…
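
As an illustrative sketch (using the public `torch.nn.attention.flex_attention` API, and not necessarily the exact snippet the post continues with), a causal `mask_mod` takes the same indices but returns a boolean, and `create_block_mask` turns it into a `BlockMask` that `flex_attention` can exploit:

```py
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Same (b, h, q_idx, kv_idx) arguments as a score_mod, but returning a bool:
# True means this query/key position pair participates in attention.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 2048, 64  # assumes a CUDA device is available
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# B=None / H=None broadcast the mask over batch and heads; the mask_mod is
# evaluated at block granularity to record which blocks are fully masked.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

# Fully masked blocks are skipped entirely, which is where the sparsity win comes from.
out = flex_attention(q, k, v, block_mask=block_mask)
```

In practice you would wrap `flex_attention` in `torch.compile` to get the fused kernel; eager mode is mostly useful for checking correctness.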

…Although the results are not bitwise identical, we are confident that FlexAttention…

### Performance
Generally speaking, FlexAttention is nearly as performant as a handwritten Triton kernel, which is unsurprising, as we heavily leverage a handwritten Triton kernel. However, due to its generality, we do incur a small performance penalty. For example, we must incur some additional latency to determine which block to compute next. In some cases, we provide some kernel options that can affect the performance of the kernel while changing its behavior. They can be found here: [performance knobs](https://github.com/pytorch/pytorch/blob/ee09d066d35d7e17cf7e9479c0b8bfc70cffc264/torch/_inductor/kernel/flex_attention.py#L146-L155)
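
As a loose illustration of toggling such a knob (reusing `q`, `k`, `v`, and `block_mask` from the sketch above): this assumes the `kernel_options` argument that recent `flex_attention` releases expose, which is not something this excerpt states, and `"PRESCALE_QK"` is only an assumed example key; check the linked file for the actual names and defaults.

```py
# Hypothetical sketch only: kernel_options (if present in your PyTorch version)
# is a dict forwarded to the underlying Triton kernel.
out = flex_attention(
    q, k, v,
    block_mask=block_mask,
    kernel_options={"PRESCALE_QK": True},  # assumed knob: slightly faster, slightly less accurate
)
```
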
As a case study, let's explore how the knobs affect the performance of causal attention. We will compare performance of the triton kernel versus FlashAttentionv2 on A100. The script can be found [here](https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/score_mod.py).
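
A heavily simplified sketch of that kind of comparison (my own reduction of the linked script, assuming a CUDA GPU such as an A100) might look like:

```py
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.utils.benchmark import Timer

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 16, 16, 4096, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

compiled_flex = torch.compile(flex_attention)
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
compiled_flex(q, k, v, block_mask=block_mask)  # warm up (triggers compilation)

# torch.utils.benchmark.Timer handles CUDA synchronization for us.
flex_t = Timer("compiled_flex(q, k, v, block_mask=block_mask)", globals=globals()).timeit(30)
sdpa_t = Timer("F.scaled_dot_product_attention(q, k, v, is_causal=True)", globals=globals()).timeit(30)
print(f"flex_attention (causal BlockMask): {flex_t.mean * 1e3:.2f} ms")
print(f"SDPA (FlashAttention backend):     {sdpa_t.mean * 1e3:.2f} ms")
```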