Please head there instead!

If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.

If you are wondering about what building blocks the ``torch`` library provides

* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

``scaled_dot_product_attention`` is a primitive for
:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
implementations of the operator or a fallback implementation. It works out of
the box in eager mode (i.e. the default mode of using PyTorch where operations
are executed on the fly as they are encountered) and also integrates seamlessly
# The improvements are threefold:
#
# * User Experience
# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
# ``value`` to be dense ``torch.Tensor``s. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
# that arise due to different sequence lengths within a batch. Since there is
4. Apply output projection
Args:
    query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
    key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
    value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
    attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
    is_causal (bool, optional): Whether to apply causal mask. Default: False
Returns:
return attn_output
###############################################################################
# Utilities
# ========================
# In this section, we include a utility to generate semi-realistic data using
# Zipf distribution for sentence lengths. This is used to generate the nested
# query, key and value tensors. We also include a benchmark utility.
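#
# As a rough, illustrative sketch (the actual ``gen_batch`` and ``benchmark``
# helpers used below may differ in their details), a Zipf-based generator of
# nested (jagged) query, key and value tensors could look like this; the
# Zipf parameter and the length clamp are arbitrary choices for the sketch:

import numpy as np

def gen_batch_sketch(N, E_q, E_k, E_v, device, dtype=torch.float32):
    # Sample per-sentence lengths from a Zipf distribution, clamped to [1, 512]
    sentence_lengths = [min(max(int(np.random.zipf(2.0)), 1), 512) for _ in range(N)]

    def jagged(E):
        # One (length, E) tensor per sentence, packed into a jagged-layout NJT
        return torch.nested.nested_tensor(
            [torch.randn(l, E, dtype=dtype, device=device) for l in sentence_lengths],
            layout=torch.jagged,
        )

    return jagged(E_q), jagged(E_k), jagged(E_v), sentence_lengths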
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')

# ``nn.MultiheadAttention`` uses a non-conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
for i, s in enumerate(sentence_lengths):
# classification of modifications to the transformer architecture, recall that we
# classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
#
# In this section, we will discuss various functionalities using the
# aforementioned building blocks. In particular,
#
# * Packed Projection
# * Cross Attention
# * Fully masked rows no longer cause ``NaN``s
# * Modifying attention score: ALiBi with FlexAttention and NJT
# * [TODO] KV-Caching with NJT
# * [TODO] Grouped Query Attention with NJT
# Packed Projection
# -----------------
#
# Packed projection is a technique that makes use of the fact that when the inputs
# for projection (matrix multiplications) are the same (self-attention), we can pack the projection
# weights and biases into single tensors. It is especially useful when the individual
# projections are memory bound rather than compute bound. There are
# two examples that we will demonstrate here:
#
# * Input projection for MultiheadAttention
# * SwiGLU activation in feed-forward network of Transformer Layer
#
# Input projection for MultiheadAttention
# ----------------------------------------
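# As a minimal sketch of the idea (illustrative only; the class name below is
# hypothetical and not the one used elsewhere in this tutorial), the three
# separate input projections can be fused into a single wider ``nn.Linear``
# whose output is chunked into query, key and value:

class PackedInputProjectionSketch(nn.Module):
    def __init__(self, E, bias=False):
        super().__init__()
        # One fused weight of shape (3 * E, E) instead of three (E, E) weights
        self.packed_proj = nn.Linear(E, 3 * E, bias=bias)

    def forward(self, x):
        # For self-attention, query == key == value == x, so a single matmul
        # followed by a chunk replaces three separate (memory-bound) matmuls
        return torch.chunk(self.packed_proj(x), 3, dim=-1)

# Usage sketch on a dense input; chunking along the last dim also composes with NJT
q_proj, k_proj, v_proj = PackedInputProjectionSketch(E_q)(torch.randn(2, 10, E_q))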
# SwiGLU feed forward network of Transformer Layer
# ------------------------------------------------
# SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
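#
# .. math::
#
#    \text{SwiGLUFFN}(x) = W_2 \,\big(\text{SiLU}(W_1 x) \odot W_3 x\big)
#
# where :math:`\text{SiLU}(x) = x \cdot \sigma(x)`. (This is the Llama-style
# convention; the exact weight shapes and the ``multiple_of`` rounding are
# handled in the module below.)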
class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
################################################################################
# ALiBi with NJT (FlexAttention + NJT)
# ------------------------------------
# NJT also composes with the ``FlexAttention`` module. This is a generalization
# of ``scaled_dot_product_attention`` that allows for arbitrary modifications
# to the attention score. The example below takes the ``alibi_mod`` from
# `attention-gym <https://github.com/pytorch-labs/attention-gym>`_ and uses it
# with nested input tensors.
from torch.nn.attention.flex_attention import flex_attention

def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H

    Args:
        H: number of heads

    Returns:
        alibi_bias: alibi bias score_mod
    """
    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias
    return alibi_mod

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = (
    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
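
################################################################################
# A few notes on the example above: ``generate_alibi_bias`` closes over the head
# count ``H`` so that each head ``h`` gets its own slope :math:`2^{-8(h+1)/H}`,
# and the returned ``score_mod`` adds a bias proportional to the distance between
# the query and key positions. Because the inputs are nested (jagged) tensors,
# the ``flex_attention`` output should keep the same jagged structure, with one
# entry per variable-length sequence in the batch.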
################################################################################
# And more
# --------
#
# We intend to update this tutorial to demonstrate more examples of how to use
# the various performant building blocks, such as KV-Caching and Grouped Query
# Attention.
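#
# As a small preview (a sketch, not part of the benchmarks above), grouped query
# attention can already be expressed with the ``enable_gqa`` flag of
# ``scaled_dot_product_attention`` in recent PyTorch releases, using fewer
# key/value heads than query heads; the sizes below are arbitrary:

import torch.nn.functional as F

N_b, H_q, H_kv, L_seq, D_head = 2, 8, 2, 32, 64
gqa_query = torch.randn(N_b, H_q, L_seq, D_head, device=device)
# Fewer key/value heads than query heads shrinks the KV-cache during inference
gqa_key = torch.randn(N_b, H_kv, L_seq, D_head, device=device)
gqa_value = torch.randn(N_b, H_kv, L_seq, D_head, device=device)
gqa_out = F.scaled_dot_product_attention(
    gqa_query, gqa_key, gqa_value, is_causal=True, enable_gqa=True
)
print(gqa_out.shape)  # (N_b, H_q, L_seq, D_head): one output per query head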
################################################################################
# There are several good examples of using various performant building blocks to
# implement various transformer architectures. Some examples include
#
# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
# * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_