|
3 | 3 | =============================================================================================================
|
4 | 4 | **Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_
|
5 | 5 |
|
6 |
| -.. note:: |
7 |
| - This tutorial should be run with the latest nightly, or, when available, 2.6. |
| 6 | +.. grid:: 2 |
| 7 | +
|
| 8 | + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn |
| 9 | + :class-card: card-prerequisites |
| 10 | +
|
| 11 | + * Learn about the low-level building blocks PyTorch provides to build custom transformer layers ( |
| 12 | + nested tensors, ``scaled_dot_product_attention``, ``torch.compile()``, and ``FlexAttention``) |
| 13 | + * Discover how the above improve memory usage and performance using MultiHeadAttention as an example |
| 14 | + * Explore advanced customizations using the aforementioned building blocks |
| 15 | + |
| 16 | + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites |
| 17 | + :class-card: card-prerequisites |
| 18 | +
|
| 19 | + * PyTorch v2.6.0 or later |
| 20 | +
|
8 | 21 |
|
9 | 22 | Over the past few years, the PyTorch team has developed various lower-level
|
10 | 23 | features that, when composed, can create a variety of transformer variants. These
|
11 | 24 | include:
|
12 | 25 |
|
13 |
| -1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs) |
14 |
| -2. ``scaled_dot_product_attention`` |
15 |
| -3. ``torch.compile()`` |
16 |
| -4. ``FlexAttention`` |
| 26 | +* Nested Tensors with the ``torch.jagged`` layout (AKA NJTs) |
| 27 | +* ``scaled_dot_product_attention`` |
| 28 | +* ``torch.compile()`` |
| 29 | +* ``FlexAttention`` |
17 | 30 |
|
18 | 31 | This tutorial will give a brief overview of the above technologies and
|
19 | 32 | demonstrate how they can be composed to yield flexible and performant transformer
|
20 | 33 | layers with improved user experience.
|
21 | 34 |
|
22 | 35 | One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers.
|
23 |
| -In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``, |
| 36 | +In particular, it includes ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``, |
24 | 37 | ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
|
25 | 38 | of layers was initially implemented following the `Attention is All
|
26 | 39 | You Need <https://arxiv.org/abs/1706.03762>`_ paper. The components discussed in
|
27 | 40 | this tutorial provide improved user experience, flexibility and performance over
|
28 | 41 | the existing ``nn`` layers.
|
29 | 42 |
|
| 43 | +
|
30 | 44 | Is this tutorial for me?
|
31 | 45 | ========================
|
32 | 46 |
|
33 | 47 | If you are wondering what building blocks the ``torch`` library provides
|
34 | 48 | for writing your own transformer layers, and what the best practices are, you are in the
|
35 |
| -right place, please keep reading! |
| 49 | +right place. Please keep reading! |
36 | 50 |
|
37 | 51 | If you are looking for an out-of-the-box implementation of a popular transformer
|
38 | 52 | architecture, note that there are many open-source libraries that provide them,
|
39 |
| -with some examples being: |
| 53 | +including: |
40 | 54 |
|
41 | 55 | * `HuggingFace transformers <https://github.com/huggingface/transformers>`_
|
42 | 56 | * `xformers <https://github.com/facebookresearch/xformers>`_
|
43 | 57 | * `torchtune <https://github.com/pytorch/torchtune>`_
|
44 | 58 |
|
45 | 59 | If you are only interested in performant attention score modifications, please
|
46 |
| -head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that |
| 60 | +check out the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that |
47 | 61 | contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
|
48 | 62 |
|
49 | 63 | """
|
50 | 64 |
|
51 | 65 | ################################################################################
|
52 | 66 | # Introducing the Building Blocks
|
53 | 67 | # ===============================
|
54 |
| -# First, we will briefly introduce the 4 technologies mentioned in the introduction |
| 68 | +# First, we will briefly introduce the four technologies mentioned in the introduction: |
55 | 69 | #
|
56 | 70 | # * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
|
57 | 71 | #
|
|
79 | 93 | # and ``scaled_dot_product_attention`` work seamlessly with compile. In the
|
80 | 94 | # context of transformers, the value of using compile with nested tensors
|
81 | 95 | # and SDPA is that compile can remove the framework overhead one sees in eager mode
|
82 |
| -# and fuse sequences of ops in transformers together (e.g. projection and |
83 |
| -# activation). |
| 96 | +# and fuse sequences of ops in transformers together, such as projection and |
| 97 | +# activation. |
84 | 98 | #
|
85 | 99 | # * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
|
86 | 100 | #
|
|
97 | 111 | # Blocks and Feed Forward networks. If we were to try to classify the differences
|
98 | 112 | # in this space, we might land on something like:
|
99 | 113 | #
|
100 |
| -# 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions |
101 |
| -# e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.) |
102 |
| -# 2. Layer ordering (where to apply norms, where to apply positional encoding etc.) |
103 |
| -# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.) |
| 114 | +# 1. Layer type (activation functions such as ``SwiGLU``, normalization functions |
| 115 | +# such as ``RMSNorm``, and positional encodings such as Sinusoidal or Rotary); a small sketch follows this list. |
| 116 | +# 2. Layer ordering, such as where to apply norms and positional encodings. |
| 117 | +# 3. Modifications to the attention score, such as ``ALiBi`` or Relative Positional Bias. |
104 | 118 | #
|
105 | 119 | #
|
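###############################################################################
# As a small illustration of the first two categories, here is a hedged sketch
# of a pre-norm feed-forward block that swaps ``LayerNorm`` for ``RMSNorm`` and
# uses a SiLU activation. The module name and sizes are hypothetical and chosen
# only for demonstration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPreNormFFNBlock(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Layer-type customization: RMSNorm instead of LayerNorm.
        self.norm = nn.RMSNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        # Layer-ordering customization: normalize before the feed-forward (pre-norm).
        return x + self.fc2(F.silu(self.fc1(self.norm(x))))

block = TinyPreNormFFNBlock(dim=64, hidden_dim=256)
out = block(torch.randn(2, 10, 64))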
106 |
| -# In a pre-compiler world, one might write their custom transformer and observe |
107 |
| -# that it works but is slow. Then, one might write a custom fused kernel for |
108 |
| -# the specific series of ops. In a compiler world, one can do the former, compile |
109 |
| -# and profit. |
| 120 | +# In a pre-compiler environment, you might write a custom transformer and notice |
| 121 | +# that it functions correctly but is slow. To address this, you might develop a |
| 122 | +# custom fused kernel for the specific series of operations. In a compiler environment, |
| 123 | +# you can simply write the custom transformer, compile it, and benefit from the improved performance. |
110 | 124 |
|
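###############################################################################
# To make this concrete, here is a minimal sketch that composes three of the
# building blocks above: a nested tensor with the ``torch.jagged`` layout,
# ``scaled_dot_product_attention``, and ``torch.compile()``. The helper function
# and sizes below are hypothetical and for illustration only.

import torch
import torch.nn as nn
import torch.nn.functional as F

E_embed, n_heads = 16, 2
E_head = E_embed // n_heads

# A batch of three sequences of different lengths, represented without padding.
sequences = [torch.randn(seq_len, E_embed) for seq_len in (3, 5, 2)]
njt = torch.nested.nested_tensor(sequences, layout=torch.jagged)

proj = nn.Linear(E_embed, E_embed)

def tiny_self_attention(x):
    # Project, split into heads, and call SDPA directly on the jagged tensor.
    q = proj(x).unflatten(-1, [n_heads, E_head]).transpose(1, 2)
    attn = F.scaled_dot_product_attention(q, q, q)
    return attn.transpose(1, 2).flatten(-2)

compiled_self_attention = torch.compile(tiny_self_attention)
out = compiled_self_attention(njt)  # the output is again a jagged nested tensor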
111 | 125 |
|
112 | 126 | ###############################################################################
|
113 | 127 | # MultiheadAttention
|
114 | 128 | # ------------------
|
115 |
| -# Recall that MultiheadAttention takes in a query, key and value and consists |
| 129 | +# Remember that MultiheadAttention takes in a query, key, and value, and consists |
116 | 130 | # of an input projection, a ``scaled_dot_product_attention`` operator and an
|
117 | 131 | # output projection. The main takeaway we want to demonstrate here is the
|
118 | 132 | # improvement yielded when we replaced padded/masked inputs with nested tensors.
|
119 | 133 | # The improvements are threefold:
|
120 | 134 | #
|
121 |
| -# * User Experience |
122 |
| -# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and |
| 135 | +# * **User Experience** |
| 136 | +# Remember that ``nn.MultiheadAttention`` requires ``query``, ``key``, and |
123 | 137 | # ``value`` to be dense ``torch.Tensors``. It also provides a
|
124 | 138 | # ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
|
125 | 139 | # that arise due to different sequence lengths within a batch. Since there is
|
126 | 140 | # no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
|
127 |
| -# the outputs appropriately to account for query sequence lengths. Nested tensor |
| 141 | +# the outputs appropriately to account for query sequence lengths. ``NestedTensor`` |
128 | 142 | # cleanly removes the need for such error-prone padding masks.
|
129 | 143 | #
|
130 |
| -# * Memory |
| 144 | +# * **Memory** |
131 | 145 | # Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
|
132 | 146 | # padding mask (where ``B`` is batch size, ``S`` is max sequence length in the
|
133 | 147 | # batch and ``D`` is embedding size), nested tensors allow you to cleanly
|
134 | 148 | # represent the batch of varying sequence lengths. As a result, the input and
|
135 | 149 | # intermediate activations will use less memory.
|
136 | 150 | #
|
137 |
| -# * Performance |
| 151 | +# * **Performance** |
138 | 152 | # Since padding is not materialized and unnecessary computation on padding is
|
139 | 153 | # skipped, performance and memory usage improve.
|
140 | 154 | #
|
141 |
| -# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the |
| 155 | +# We'll demonstrate the above by building upon the ``MultiheadAttention`` layer in the |
142 | 156 | # `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
|
143 | 157 | # and comparing it to the ``nn.MultiheadAttention`` layer.
|
144 | 158 |
|
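###############################################################################
# To make the memory point concrete, here is a small, hypothetical comparison of
# the dense-plus-mask representation expected by ``nn.MultiheadAttention`` and
# the nested representation. The sizes are made up for illustration.

import torch

B, S_max, D = 4, 10, 8
lengths = [3, 10, 5, 7]

# Dense route: materialize a [B, S_max, D] tensor plus a [B, S_max] key_padding_mask.
padded = torch.zeros(B, S_max, D)
key_padding_mask = torch.ones(B, S_max, dtype=torch.bool)  # True marks padding for nn.MHA
for i, seq_len in enumerate(lengths):
    padded[i, :seq_len] = torch.randn(seq_len, D)
    key_padding_mask[i, :seq_len] = False

# Nested route: only the real tokens are stored and no mask bookkeeping is needed.
nested = torch.nested.nested_tensor(
    [torch.randn(seq_len, D) for seq_len in lengths], layout=torch.jagged
)

print(padded.numel())                           # 320 elements, many of them padding
print(sum(seq_len * D for seq_len in lengths))  # 200 elements actually needed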
@@ -257,8 +271,8 @@ def forward(self,
|
257 | 271 | # Utilities
|
258 | 272 | # ~~~~~~~~~
|
259 | 273 | # In this section, we include a utility to generate semi-realistic data using
|
260 |
| -# Zipf distribution for sentence lengths. This is used to generate the nested |
261 |
| -# query, key and value tensors. We also include a benchmark utility. |
| 274 | +# a Zipf distribution for sentence lengths. This is used to generate the nested |
| 275 | +# query, key, and value tensors. We also include a benchmark utility. |
262 | 276 |
|
263 | 277 |
|
264 | 278 | import numpy as np
|
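###############################################################################
# As a rough illustration (a hypothetical helper, not necessarily the exact
# utility used here), a Zipf-based sentence-length sampler might look like this:

import numpy as np
import torch

def sample_zipf_lengths(alpha: float, batch_size: int, max_len: int = 512) -> torch.Tensor:
    # Zipf-distributed lengths: many short sentences and a long tail of long ones.
    lengths = np.random.zipf(alpha, size=batch_size)
    return torch.tensor(np.clip(lengths, 1, max_len), dtype=torch.int64)

print(sample_zipf_lengths(alpha=1.2, batch_size=8))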
@@ -393,7 +407,7 @@ def benchmark(func, *args, **kwargs):
|
393 | 407 | print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")
|
394 | 408 |
|
395 | 409 | ######################################################################################
|
396 |
| -# For reference some sample outputs on A100: |
| 410 | +# For reference, here are some sample outputs on an A100: |
397 | 411 | #
|
398 | 412 | # .. code::
|
399 | 413 | #
|
@@ -456,13 +470,13 @@ def benchmark(func, *args, **kwargs):
|
456 | 470 | # ----------------------
|
457 | 471 | # So far, we have demonstrated how to implement a performant ``MultiheadAttention``
|
458 | 472 | # layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
|
459 |
| -# classification of modifications to the transformer architecture, recall that we |
| 473 | +# classification of modifications to the transformer architecture, remember that we |
460 | 474 | # classified the modifications into layer type, layer ordering, and modifications
|
461 | 475 | # to the attention score. We trust that changing layer type and layer ordering
|
462 |
| -# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward. |
| 476 | +# (such as swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward. |
463 | 477 | #
|
464 | 478 | # In this section, we will discuss various functionalities using the
|
465 |
| -# aforementioned building blocks. In particular, |
| 479 | +# aforementioned building blocks, including the following: |
466 | 480 | #
|
467 | 481 | # * Cross Attention
|
468 | 482 | # * Fully masked rows no longer cause NaNs
|
@@ -595,8 +609,10 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
|
595 | 609 | # In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
|
596 | 610 | # with NJTs via the ``create_nested_block_mask`` function. This is useful for
|
597 | 611 | # taking advantage of the sparsity of the mask to speed up the attention computation.
|
598 |
| -# In the following example, we show how to create a causal block mask using this |
599 |
| -# utility. |
| 612 | +# In particular, the function creates a sparse block mask for a "stacked sequence" of all |
| 613 | +# the variable-length sequences in the NJT combined into one, while properly masking out |
| 614 | +# inter-sequence attention. In the following example, we show how to create a |
| 615 | +# causal block mask using this utility. |
600 | 616 |
|
601 | 617 | from torch.nn.attention.flex_attention import create_nested_block_mask
|
602 | 618 |
|
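###############################################################################
# A hedged sketch of how this utility can be wired up. The sizes are hypothetical
# and a CUDA device is assumed, since FlexAttention is typically run compiled on
# a GPU.

import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # Each query position may only attend to itself and earlier positions.
    return q_idx >= kv_idx

n_heads, E_head = 4, 16
seqs = [torch.randn(seq_len, n_heads * E_head, device="cuda") for seq_len in (5, 9, 3)]
query = (
    torch.nested.nested_tensor(seqs, layout=torch.jagged)
    .unflatten(-1, [n_heads, E_head])
    .transpose(1, 2)
)  # shape (B, n_heads, j, E_head)

block_mask = create_nested_block_mask(causal, 1, 1, query)
out = torch.compile(flex_attention)(query, query, query, block_mask=block_mask)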
@@ -629,7 +645,7 @@ def causal_mask(b, h, q_idx, kv_idx):
|
629 | 645 | #
|
630 | 646 | # Input projection for MultiheadAttention
|
631 | 647 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
632 |
| -# Recall that when doing self-attention, the ``query``, ``key`` and ``value`` |
| 648 | +# When doing self-attention, the ``query``, ``key``, and ``value`` |
633 | 649 | # are the same tensor. Each of these tensors is projected with a
|
634 | 650 | # ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer,
|
635 | 651 | # which is what we do in the MultiheadAttention layer above.
|
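###############################################################################
# A minimal sketch of the packed-projection idea (the class name and sizes are
# hypothetical): a single ``Linear(E_q, 3 * E_total)`` replaces three separate
# projections, so one matmul produces query, key, and value.

import torch
import torch.nn as nn

class PackedQKVProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False):
        super().__init__()
        self.packed_proj = nn.Linear(E_q, 3 * E_total, bias=bias)

    def forward(self, x):
        # One matmul, then split the packed result into q, k, and v.
        return self.packed_proj(x).chunk(3, dim=-1)

proj = PackedQKVProjection(E_q=64, E_total=64)
q, k, v = proj(torch.randn(2, 10, 64))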
@@ -677,8 +693,8 @@ def forward(self, query):
|
677 | 693 | ##################################################
|
678 | 694 | # SwiGLU feed forward network of Transformer Layer
|
679 | 695 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
680 |
| -# SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward |
681 |
| -# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as |
| 696 | +# The Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward |
| 697 | +# network of transformer layers (for example, in Llama). A feed-forward network with SwiGLU activation is defined as: |
682 | 698 |
|
683 | 699 | class SwiGLUFFN(nn.Module):
|
684 | 700 | def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
|
@@ -751,3 +767,12 @@ def forward(self, x):
|
751 | 767 | # * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
|
752 | 768 | # * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
|
753 | 769 | # * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_
|
| 770 | + |
| 771 | +################################################################################ |
| 772 | +# Conclusion |
| 773 | +# ---------- |
| 774 | +# |
| 775 | +# In this tutorial, we have introduced the low-level building blocks PyTorch |
| 776 | +# provides for writing transformer layers and demonstrated examples of how to compose |
| 777 | +# them. We hope this tutorial has shown how easily flexible and |
| 778 | +# performant transformer layers can be implemented with PyTorch. |