*Source: PaliGemma: [A versatile 3B VLM for transfer](https://arxiv.org/abs/2407.07726)*
The T5 architecture, proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), describes an attention variant that performs full bidirectional attention on a “prefix” and causal attention on the rest. We again compose two mask functions to accomplish this: one for causal masking and one based on the prefix length.
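To make this concrete, here is a minimal sketch of how the two mask functions might be composed, assuming a hypothetical per-sequence `prefix_length` tensor and using the `or_masks` and `create_block_mask` helpers from `torch.nn.attention.flex_attention`:

```
import torch
from torch.nn.attention.flex_attention import create_block_mask, or_masks

B, S = 4, 1024  # example batch size and sequence length
# Hypothetical per-sequence prefix lengths (e.g. the instruction portion of each sample)
prefix_length = torch.randint(128, S, (B,), device="cuda")

def prefix_mask(b, h, q_idx, kv_idx):
    # Fully bidirectional: every query may attend to every key inside its sequence's prefix
    return kv_idx < prefix_length[b]

def causal_mask(b, h, q_idx, kv_idx):
    # Standard causal masking for everything else
    return q_idx >= kv_idx

# A (q, kv) pair is kept if the key lies in the prefix OR the pair is causal
prefix_lm_causal = or_masks(prefix_mask, causal_mask)

# The mask differs per sequence, so B is set to the batch size when building the BlockMask
block_mask = create_block_mask(prefix_lm_causal, B=B, H=None, Q_LEN=S, KV_LEN=S)
```

Because `prefix_length` is indexed by the batch index `b`, each sequence in the batch can have its own prefix boundary, which is why the `BlockMask` is built with `B` equal to the batch size.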
Another common pattern is document masking for jagged sequences, where several documents of varying length are packed into a single sequence and tokens should only attend within their own document. Through `BlockMask`, we can support this efficiently in FlexAttention as well!
```
document_id: [SEQ_LEN]

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]
```
And that’s it! In this case, we see that we end up with a block-diagonal mask.
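For completeness, here is a rough usage sketch with hypothetical document lengths, using the `create_block_mask` and `flex_attention` entry points from `torch.nn.attention.flex_attention`; the block-diagonal mask is built once and then passed to the attention call:

```
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical packed batch: three documents of different lengths in one sequence
doc_lengths = [256, 512, 256]
SEQ_LEN = sum(doc_lengths)
document_id = torch.cat(
    [torch.full((n,), i, dtype=torch.long) for i, n in enumerate(doc_lengths)]
).to("cuda")

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]

# Materialize the sparse block structure once; it can be reused across layers
block_mask = create_block_mask(document_masking, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

q = k = v = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)
```

Since the mask depends only on token positions and not on the batch or head index, `B` and `H` are left as `None` so the same `BlockMask` is broadcast across every batch element and head.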
We look forward to leveraging the approach we used here for more applications in the future.
### Limitations and Future Work
- FlexAttention is currently available in PyTorch nightly releases; we plan to release it as a prototype feature in 2.5.0.
- We did not cover how to use FlexAttention for inference here (or how to implement PagedAttention) - we will cover those in a later post.
- We are working to improve the performance of FlexAttention to match FlashAttention3 on H100 GPUs.
- FlexAttention requires that all sequence lengths be a multiple of 128 - this will be addressed soon.