Commit 94bde8f

Address all comments
1 parent 1340a86 commit 94bde8f

1 file changed: 64 additions, 39 deletions

intermediate_source/transformer_building_blocks.py
@@ -3,55 +3,69 @@
=============================================================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

-.. note::
-    This tutorial should be run with the latest nightly, or, when available, 2.6.
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+       :class-card: card-prerequisites
+
+       * Learn about the low-level building blocks PyTorch provides to build custom transformer layers (
+         nested tensors, ``scaled_dot_product_attention``, ``torch.compile()``, and ``FlexAttention``)
+       * Discover how the above improve memory usage and performance using MultiHeadAttention as an example
+       * Explore advanced customizations using the aforementioned building blocks
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+       :class-card: card-prerequisites
+
+       * PyTorch v.2.6.0 or later
+

Over the past few years, the PyTorch team has developed various lower level
features that, when composed, can create a variety of transformer variants. These
include:

-1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
-2. ``scaled_dot_product_attention``
-3. ``torch.compile()``
-4. ``FlexAttention``
+* Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
+* ``scaled_dot_product_attention``
+* ``torch.compile()``
+* ``FlexAttention``

This tutorial will give a brief overview of the above technologies and
demonstrate how they can be composed to yield flexible and performant transformer \
layers with improved user experience.

One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers.
-In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
+In particular, it includes ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
of layers was initially implemented following the `Attention is All
You Need <https://arxiv.org/abs/1706.03762>`_ paper. The components discussed in
this tutorial provide improved user experience, flexibility and performance over
the existing ``nn`` layers.

+
Is this tutorial for me?
========================

If you are wondering about what building blocks the ``torch`` library provides
for writing your own transformer layers and best practices, you are in the
-right place, please keep reading!
+right place. Please keep reading!

If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,
-with some examples being:
+including:

* `HuggingFace transformers <https://github.com/huggingface/transformers>`_
* `xformers <https://github.com/facebookresearch/xformers>`_
* `torchtune <https://github.com/pytorch/torchtune>`_

If you are only interested in performant attention score modifications, please
-head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
+check out the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.

"""

################################################################################
# Introducing the Building Blocks
# ===============================
-# First, we will briefly introduce the 4 technologies mentioned in the introduction
+# First, we will briefly introduce the four technologies mentioned in the introduction
#
# * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
#
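
For readers who have not used the ``torch.jagged`` layout referenced above, the following is a minimal sketch (not part of the commit) of how a nested tensor is built from variable-length sequences; the sequence lengths and embedding size are illustrative assumptions:

    import torch

    # three sequences of different lengths, each with embedding size 64 (illustrative)
    sequences = [torch.randn(seq_len, 64) for seq_len in (3, 7, 5)]

    # nested tensor with the jagged (NJT) layout: no padding is materialized
    njt = torch.nested.nested_tensor(sequences, layout=torch.jagged)
    print(njt.is_nested)   # True
    print(njt.size(0))     # 3 -- batch dimension; the sequence dimension is ragged
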
@@ -79,8 +93,8 @@
# and ``scaled_dot_product_attention`` work seamlessly with compile. In the
# context of transformers, the value add of using compile with nested tensor
# and SDPA is that compile can remove framework overhead one sees in eager mode
-# and fuse sequences of ops in transformers together (e.g. projection and
-# activation).
+# and fuse sequences of ops in transformers together, such as projection and
+# activation.
#
# * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
#
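
As a rough illustration of the point above about ``torch.compile()`` removing eager-mode overhead and fusing projection/activation sequences, here is a minimal sketch; the module and sizes are assumptions made for the example, not code from the tutorial:

    import torch
    import torch.nn.functional as F

    class FeedForward(torch.nn.Module):
        def __init__(self, dim, hidden):
            super().__init__()
            self.up = torch.nn.Linear(dim, hidden)
            self.down = torch.nn.Linear(hidden, dim)

        def forward(self, x):
            # projection followed by activation: the kind of op sequence compile can fuse
            return self.down(F.relu(self.up(x)))

    ff = torch.compile(FeedForward(64, 256))   # removes framework overhead seen in eager mode
    y = ff(torch.randn(2, 8, 64))              # [batch, seq, dim]
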
@@ -97,48 +111,48 @@
# Blocks and Feed Forward networks. If we were to try to classify the differences
# in this space, we might land on something like:
#
-# 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
-#    e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
-# 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
-# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
+# 1. Layer type (activation functions such as ``SwiGLU``, normalization functions
+#    such as ``RMSNorm``, positional encodings such as Sinusoidal and Rotary)
+# 2. Layer ordering, such as where to apply norms and positional encoding.
+# 3. Modifications to attention score, such as ``ALiBi``, Relative Positional Bias and so on.
#
#
-# In a pre-compiler world, one might write their custom transformer and observe
-# that it works but is slow. Then, one might write a custom fused kernel for
-# the specific series of ops. In a compiler world, one can do the former, compile
-# and profit.
+# In a pre-compiler environment, you might write a custom transformer and notice
+# that it functions correctly but is slow. To address this, you might develop a
+# custom fused kernel for the specific series of operations. In a compiler environment,
+# you can simply write the custom transformer, compile it, and benefit from improved performance.


###############################################################################
# MultiheadAttention
# ------------------
-# Recall that MultiheadAttention takes in a query, key and value and consists
+# Remember that MultiheadAttention takes in a query, key, and value, and consists
# of an input projection, a ``scaled_dot_product_attention`` operator and an
# output projection. The main takeaway we want to demonstrate here is the
# improvement yielded when we replaced padded/masked inputs with nested tensors.
# The improvements are threefold:
#
-# * User Experience
-#   Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
+# * **User Experience**
+#   Remember that ``nn.MultiheadAttention`` requires ``query``, ``key``, and
#   ``value`` to be dense ``torch.Tensors``. It also provides a
#   ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
#   that arise due to different sequence lengths within a batch. Since there is
#   no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
-#   the outputs appropriately to account for query sequence lengths. Nested tensor
+#   the outputs appropriately to account for query sequence lengths. ``NestedTensor``
#   cleanly removes the need for this sort of error-prone padding mask.
#
-# * Memory
+# * **Memory**
#   Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
#   padding mask (where ``B`` is batch size, ``S`` is max sequence length in the
#   batch and ``D`` is embedding size), nested tensors allow you to cleanly
#   represent the batch of varying sequence lengths. As a result, the input and
#   intermediate activations will use less memory.
#
-# * Performance
+# * **Performance**
#   Since padding is not materialized and unnecessary computation on padding is
#   skipped, performance and memory usage improve.
#
-# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
+# We'll demonstrate the above by building upon the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
# and comparing it to the ``nn.MultiheadAttention`` layer.

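
To make the memory point above concrete, here is a minimal sketch (sizes are illustrative assumptions) contrasting a padded dense batch with a nested tensor built from the same variable-length sequences:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    D = 512                                   # embedding size (illustrative)
    seq_lens = [128, 17, 256, 4]              # varying sequence lengths within the batch
    seqs = [torch.randn(s, D) for s in seq_lens]

    padded = pad_sequence(seqs, batch_first=True)                   # dense [B, max(S), D]
    nested = torch.nested.nested_tensor(seqs, layout=torch.jagged)  # no padding materialized

    print(padded.numel())                # 4 * 256 * 512 = 524288 elements
    print(sum(s * D for s in seq_lens))  # 207360 elements actually needed
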
@@ -257,8 +271,8 @@ def forward(self,
# Utilities
# ~~~~~~~~~
# In this section, we include a utility to generate semi-realistic data using
-# Zipf distribution for sentence lengths. This is used to generate the nested
-# query, key and value tensors. We also include a benchmark utility.
+# a ``Zipf`` distribution for sentence lengths. This is used to generate the nested
+# query, key, and value tensors. We also include a benchmark utility.


import numpy as np
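
As a rough sketch of the kind of sentence-length utility described above (the function name, ``alpha`` value, and clamping range are assumptions for illustration, not the tutorial's exact implementation):

    import numpy as np
    import torch

    def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
        # sample sentence lengths from a Zipf distribution and clamp to a sane range
        lengths = np.random.zipf(alpha, size=batch_size)
        return torch.as_tensor(np.clip(lengths, 1, 512), dtype=torch.int64)

    print(zipf_sentence_lengths(alpha=1.2, batch_size=8))
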
@@ -393,7 +407,7 @@ def benchmark(func, *args, **kwargs):
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

######################################################################################
-# For reference some sample outputs on A100:
+# For reference, here are some sample outputs on A100:
#
# .. code::
#
@@ -456,13 +470,13 @@ def benchmark(func, *args, **kwargs):
# ----------------------
# So far, we have demonstrated how to implement a performant ``MultiheadAttention``
# layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
-# classification of modifications to the transformer architecture, recall that we
+# classification of modifications to the transformer architecture, remember that we
# classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
+# (such as swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
#
# In this section, we will discuss various functionalities using the
-# aforementioned building blocks. In particular,
+# aforementioned building blocks, including the following:
#
# * Cross Attention
# * Fully masked rows no longer cause NaNs
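
The hunk above notes that swapping ``LayerNorm`` for ``RMSNorm`` is a straightforward layer-type change. A minimal sketch of such a drop-in layer (the ``eps`` default is an assumption; recent PyTorch releases also ship ``torch.nn.RMSNorm``):

    import torch

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # root-mean-square normalization over the last dim; no mean subtraction or bias
            rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return x * rms * self.weight
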
@@ -595,8 +609,10 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
# with NJTs via the ``create_nested_block_mask`` function. This is useful for
# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In the following example, we show how to create a causal block mask using this
-# utility.
+# In particular, the function creates a sparse block mask for a "stacked sequence" of all
+# the variable length sequences in the NJT combined into one, while properly masking out
+# inter-sequence attention. In the following example, we show how to create a
+# causal block mask using this utility.

from torch.nn.attention.flex_attention import create_nested_block_mask

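
A minimal sketch of the causal block mask described above, assuming query/key/value NJTs of shape ``[B, H, S, head_dim]`` with the ragged dimension on ``S``; the helper and sizes are assumptions, and the tutorial's full example that follows this hunk may differ in its details:

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_nested_block_mask

    def causal_mask(b, h, q_idx, kv_idx):
        # attend only to current and earlier positions within each sequence
        return q_idx >= kv_idx

    def make_njt(seq_lens, n_heads=4, head_dim=16):
        # assumed helper: [B, S, H * head_dim] NJT, then split heads -> [B, H, S, head_dim]
        nt = torch.nested.nested_tensor(
            [torch.randn(s, n_heads * head_dim) for s in seq_lens], layout=torch.jagged
        )
        return nt.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)

    query = key = value = make_njt([5, 9, 3])   # self-attention: shared ragged structure
    block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # sparse mask over the stacked sequences
    out = torch.compile(flex_attention)(query, key, value, block_mask=block_mask)
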
@@ -629,7 +645,7 @@ def causal_mask(b, h, q_idx, kv_idx):
#
# Input projection for MultiheadAttention
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# Recall that when doing self-attention, the ``query``, ``key`` and ``value``
+# When doing self-attention, the ``query``, ``key``, and ``value``
# are the same tensor. Each of these tensors is projected with a
# ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer,
# which is what we do in the MultiheadAttention layer above.
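
A minimal sketch of the packed input projection described above (the dimensions are illustrative assumptions):

    import torch

    E_q, E_total = 512, 512
    # one packed projection replaces three separate Linear(E_q, E_total) layers
    packed_proj = torch.nn.Linear(E_q, 3 * E_total)

    x = torch.randn(2, 10, E_q)                 # query == key == value in self-attention
    q, k, v = packed_proj(x).chunk(3, dim=-1)   # each is [2, 10, E_total]
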
@@ -677,8 +693,8 @@ def forward(self, query):
##################################################
# SwiGLU feed forward network of Transformer Layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
-# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
+# Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward
+# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as:

class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
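
For reference, the SwiGLU feed-forward computation that a class like the one above implements is typically ``w2(silu(w1(x)) * w3(x))``. A minimal functional sketch, with layer names and sizes assumed for illustration rather than taken from the tutorial:

    import torch
    import torch.nn.functional as F

    dim, hidden_dim = 512, 1408                          # illustrative sizes
    w1 = torch.nn.Linear(dim, hidden_dim, bias=False)    # gate projection
    w3 = torch.nn.Linear(dim, hidden_dim, bias=False)    # up projection
    w2 = torch.nn.Linear(hidden_dim, dim, bias=False)    # down projection

    x = torch.randn(2, 10, dim)
    out = w2(F.silu(w1(x)) * w3(x))                      # SwiGLU: silu-gated product, then down-project
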
@@ -751,3 +767,12 @@ def forward(self, x):
# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
# * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_
+
+################################################################################
+# Conclusion
+# ----------
+#
+# In this tutorial, we have introduced the low-level building blocks PyTorch
+# provides for writing transformer layers and demonstrated examples of how to
+# compose them. We hope this tutorial has shown the reader how easily flexible
+# and performant transformer layers can be implemented with PyTorch.
