Commit 6939adb

Fix misc blog things
1 parent eb898bd commit 6939adb

File tree

2 files changed: +5 -2 lines changed

_posts/2024-08-07-flexattention.md

Lines changed: 5 additions & 2 deletions
@@ -141,7 +141,8 @@ Soft-capping is a technique used in [Gemma2](https://huggingface.co/blog/gemma2\

 ```py
 softcap = 20
-def soft_cap(score, b, h, q_idx, kv_idx): score = score / softcap
+def soft_cap(score, b, h, q_idx, kv_idx):
+    score = score / softcap
     score = torch.tanh(score)
     score = score * softcap
     return score
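
For reference, here is a minimal sketch of how the corrected `soft_cap` score_mod from this hunk could be exercised with `flex_attention`; the tensor shapes and dtypes are placeholders chosen for illustration, not values from the post.

```py
import torch
from torch.nn.attention.flex_attention import flex_attention

softcap = 20

def soft_cap(score, b, h, q_idx, kv_idx):
    # Squash raw attention scores into (-softcap, softcap) via tanh.
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score

# Placeholder inputs: [batch, heads, seq_len, head_dim]; seq_len is a multiple of 128.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_attention(q, k, v, score_mod=soft_cap)
```
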
@@ -229,6 +230,7 @@ We benchmark it against `F.scaled_dot_product_attention` with a sliding window m
 ### PrefixLM
 
 ![PrefixLM diagram](/assets/images/flexattention/fg10.png){:style="max-width:600px; display:block; margin-left: auto; margin-right: auto; width:100%"}
+*Source: PaliGemma: [A versatile 3B VLM for transfer](https://arxiv.org/abs/2407.07726)*
 
 The T5 architecture, proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), describes an attention variant that performs full bidirectional attention on a “prefix”, and causal attention on the rest. We again compose two mask functions to accomplish this, one for causal masking and one that is based off of the prefix length.
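
As a rough sketch of the composition described above (not code from the post), the two mask functions might look like the following; `prefix_length`, the batch size, and the sequence length are hypothetical placeholders, and `create_block_mask`/`flex_attention` are the APIs introduced earlier in the post.

```py
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, SEQ_LEN, HEAD_DIM = 2, 8, 1024, 64

# Hypothetical per-sequence prefix sizes: prefix tokens attend bidirectionally.
prefix_length = torch.tensor([256, 512], device="cuda")

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx < prefix_length[b]

def prefix_lm_mask(b, h, q_idx, kv_idx):
    # Full attention inside the prefix, causal attention for everything after it.
    return causal_mask(b, h, q_idx, kv_idx) | prefix_mask(b, h, q_idx, kv_idx)

block_mask = create_block_mask(prefix_lm_mask, B=B, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)
```
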

@@ -262,7 +264,7 @@ Through `BlockMask`, we can support this efficiently in FlexAttention as well\!
 document_id: [SEQ_LEN]
 
 def document_masking(b, h, q_idx, kv_idx):
-return document_id[q_idx] == document_id[kv_idx]
+    return document_id[q_idx] == document_id[kv_idx]
 ```
 
 And that’s it\! In this case, we see that we end up with a blockdiagonal mask.
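
To make the snippet above concrete, here is a hedged sketch of turning `document_masking` into a `BlockMask` and running attention over a packed sequence; the document lengths and tensor shapes are invented for illustration.

```py
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SEQ_LEN = 1024

# Hypothetical packing: documents of lengths 256, 512, and 256 concatenated into one sequence.
document_id = torch.cat([
    torch.zeros(256, dtype=torch.int64),
    torch.ones(512, dtype=torch.int64),
    torch.full((256,), 2, dtype=torch.int64),
]).to("cuda")

def document_masking(b, h, q_idx, kv_idx):
    # A token may only attend to tokens from the same packed document.
    return document_id[q_idx] == document_id[kv_idx]

# The mask is independent of batch and head, so both can broadcast.
block_mask = create_block_mask(document_masking, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

q = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)

# The resulting mask is block-diagonal: attention never crosses document boundaries.
```
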
@@ -458,6 +460,7 @@ We look forward to leveraging the approach we used here to more applications in
 
 ### Limitations and Future Work
 
+- FlexAttention is currently available in PyTorch nightly releases, we plan to release it as a prototype feature in 2.5.0
 - We did not cover how to use FlexAttention for inference here (or how to implement PagedAttention) \- we will cover those in a later post.
 - We are working to improve the performance of FlexAttention to match FlashAttention3 on H100 GPUs.
 - FlexAttention requires that all sequence lengths be a multiple of 128 \- this will be addressed soon.

assets/images/flexattention/fg9.png

182 KB
