Commit 09642b4

Address most comments
1 parent 8c7ec76 commit 09642b4

File tree: 3 files changed, +49 -32 lines changed


.jenkins/metadata.json

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,7 @@
 },
 "recipes_source/torch_export_aoti_python.py": {
 "needs": "linux.g5.4xlarge.nvidia.gpu"
-},
+},
 "advanced_source/pendulum.py": {
 "needs": "linux.g5.4xlarge.nvidia.gpu",
 "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
@@ -58,6 +58,9 @@
 "intermediate_source/scaled_dot_product_attention_tutorial.py": {
 "needs": "linux.g5.4xlarge.nvidia.gpu"
 },
+"intermediate_source/transformer_building_blocks.py": {
+"needs": "linux.g5.4xlarge.nvidia.gpu"
+},
 "recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": {
 "needs": "linux.g5.4xlarge.nvidia.gpu"
 },

index.rst

Lines changed: 1 addition & 1 deletion
@@ -667,7 +667,7 @@ Welcome to PyTorch Tutorials

 .. customcarditem::
    :header: [Title TBD] Unbundling nn.Transformer modules for gains and profits
-   :card_description: This tutorial goes over recommended best practices for implementing Transformers.
+   :card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch.
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
    :link: intermediate/transformer_building_blocks.html
    :tags: Transformer

intermediate_source/transformer_building_blocks.py

Lines changed: 44 additions & 30 deletions
@@ -1,6 +1,6 @@
 """
-[Title TBD] Unbundling nn.Transformer modules for gains and profits
-===================================================================
+Dismantling the ``nn.Transformer`` modules for gains and profits
+======================================================================
 **Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

 The ``torch.nn`` module currently provides various ``Transformer``-related layers.
@@ -11,7 +11,7 @@
 were made to try to make these layers more flexible.

 While historically these layers intended to provide out-of-the-box, performant
-solutions. We make the observations that
+solutions, we make the observations that

 1. People want to add slight customizations to their transformer layers
 2. Writing these layers and customizations is not hard
@@ -21,7 +21,7 @@
 own performant transformer layers following our recommended best practices.
 The technologies used will be the following

-1. Nested Tensors with the ``torch.jagged`` layout
+1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
 2. ``scaled_dot_product_attention``
 3. ``torch.compile()``
 4. ``FlexAttention``
@@ -31,8 +31,13 @@

 If you are looking for an out-of-the-box implementation of a popular transformer
 architecture, note that there are many open-source libraries that provide them,
-with some examples being HuggingFace transformers and torchtune. Please head
-there instead!
+with some examples being:
+
+* `HuggingFace transformers <https://github.com/huggingface/transformers>`_
+* `xformers <https://github.com/facebookresearch/xformers>`_
+* `torchtune <https://github.com/pytorch/torchtune>`_
+
+Please head there instead!

 If you are only interested in performant attention score modifications, please
 head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
@@ -50,52 +55,56 @@
 * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_

 Nested tensors generalize the shape of regular dense tensors, allowing for
-representation of ragged-sized data. In the context of transformers,
-we can think of nested tensors as a tool for representing variable sequence
-lengths. They eliminate the need for the bug-prone practices of explicit
+representation of ragged-sized data with the same tensor UX. In the context of
+transformers, we can think of nested tensors as a tool for representing variable
+sequence lengths. They eliminate the need for the bug-prone practices of explicit
 padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).

 *`scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

 ``scaled_dot_product_attention`` is a primitive for
 $\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
 implementations of the operator or a fallback implementation. It works out of
-the box in eager mode and also integrates seamlessly with compile.
-As of 2.6, it will also offer grouped query attention natively.
+the box in eager mode (i.e. the default mode of using PyTorch where operations
+are executed on the fly as they are encountered) and also integrates seamlessly
+with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
+natively.

 * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

 ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
-fuse together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+capture a graph of PyTorch code and perform various optimizations on it, such as
+fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
 and ``scaled_dot_product_attention`` work seamlessly with compile. In the
 context of transformers, the value add of using compile with nested tensor
-and sdpa is that compile can remove framework overhead ones sees in eager mode
+and SDPA is that compile can remove framework overhead one sees in eager mode
 and fuse sequences of ops in transformers together (e.g. projection and
 activation).

 * `FlexAttention <https://pytorch.org/blog/flexattention/>`_

 ``FlexAttention`` is a primitive that allows users to modify attention scores
 prior to the softmax operation. It generalizes the additive ``B`` term above
-for `scaled_dot_product_attention` into allowing you to do any op. It requires
-compile to achieve good performance.
+for `scaled_dot_product_attention`, allowing for arbitrary calculation. It
+requires compile to achieve good performance.

 The above building blocks are "All You Need" (as of October 2024)
 ==================================================================

-The main premise in this section is that most transformers these days are
+The main premise in this section is that most transformer variations are
 GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
 Blocks and Feed Forward networks. If we were to try to classify the differences
-in this space we might land on something like
+in this space, we might land on something like:

-1. Layer type (activation functions e.g. SwiGLU, normalization functions
-e.g. RMSNorm etc., positional encodings e.g. Sinusoidal, Rotary etc.)
+1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
+e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
-3. Modifications to attention score (ALiBi, Relative Positional Bias etc.)
+3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)

 In a pre-compiler world, one might write their custom transformer and observe
 that it works but is slow. Then, one might write a custom fused kernel for
-series of ops. In a compiler world, one can do the former, compile and profit.
+the specific series of ops. In a compiler world, one can do the former, compile
+and profit.

 """

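The three building blocks above compose directly. Below is a minimal sketch of that workflow, not code from the tutorial or this commit, assuming PyTorch 2.5 or newer; the sequence lengths, head counts, and the ``to_heads`` helper are illustrative, and the fused SDPA kernels for jagged inputs are primarily exercised on CUDA devices.

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
n_heads, head_dim = 2, 4
E = n_heads * head_dim

# 1. A nested tensor with the ``torch.jagged`` layout: three sequences of
#    different lengths, with no padding materialized.
seqs = [torch.randn(L, E, device=device) for L in (3, 5, 2)]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)   # shape (B, j1, E)

# Split the embedding dim into heads: (B, n_heads, j1, head_dim).
def to_heads(x):
    return x.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)

q = k = v = to_heads(nt)

# 2. scaled_dot_product_attention consumes the jagged layout directly; the
#    ragged sequence dim replaces explicit padding and key_padding_mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# 3. torch.compile removes the framework overhead seen in eager mode and fuses
#    surrounding ops; NJT and SDPA both work inside compiled regions.
@torch.compile
def attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape, attention(q, k, v).shape)   # (3, 2, j1, 4) with ragged j1
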
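Similarly, a rough sketch of the ``FlexAttention`` ``score_mod`` hook described above (again illustrative rather than the tutorial's code; the ``alibi_like`` bias is a simplified stand-in for real ALiBi slopes, and depending on the PyTorch version the compiled path may require a CUDA device):

import torch
from torch.nn.attention.flex_attention import flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, S, D = 2, 4, 128, 16
q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))

# score_mod receives each raw attention score plus its (batch, head, q_idx,
# kv_idx) coordinates and returns a modified score -- here a toy linear
# distance penalty loosely in the spirit of ALiBi.
def alibi_like(score, b, h, q_idx, kv_idx):
    return score - 0.1 * (h + 1) * (q_idx - kv_idx).abs()

out = flex_attention(q, k, v, score_mod=alibi_like)   # slow eager reference

# Compile for real performance, as the text above notes.
fast_flex = torch.compile(flex_attention)
out_fast = fast_flex(q, k, v, score_mod=alibi_like)
print(out.shape, out_fast.shape)
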
@@ -125,7 +134,8 @@
 # intermediate activations will use less memory.
 #
 # * Performance
-# Since unnecessary computation on padding is skipped, performance improves.
+# Since padding is not materialized and unnecessary computation on padding is
+# skipped, performance and memory usage improve.
 #
 # We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
 # `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
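
To make the "padding is not materialized" point concrete, here is a small illustrative sketch (the sizes are arbitrary and not taken from the tutorial):

import torch

E = 8
seqs = [torch.randn(L, E) for L in (3, 5, 2)]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)

# The jagged layout stores only the 3 + 5 + 2 = 10 real rows...
print(nt.values().shape)                              # torch.Size([10, 8])

# ...whereas padding to the max length materializes 3 * 5 = 15 rows,
# 5 of which are wasted on padding.
print(torch.nested.to_padded_tensor(nt, 0.0).shape)   # torch.Size([3, 5, 8])
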
@@ -403,10 +413,11 @@ def benchmark(func, *args, **kwargs):
 ##################################################################################
 # GPT-style layer
 # ---------------
-# A basic GPT-style transformer layer consistst of a causal self-attention layer
+# A basic GPT-style transformer layer consists of a causal self-attention layer
 # followed by a feed-forward network (FFN) with skip connections. Implementing
 # this is fairly straightforward using the ``MultiheadAttention`` layer above and
-# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
+# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
+# ``is_causal=True``.

 # We demonstrate examples of implementing the rest of the nn layers
 # `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
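
For readers who only have this diff in front of them, here is a rough, self-contained sketch of such a block. It is written directly against ``scaled_dot_product_attention`` for dense ``(B, S, dim)`` inputs rather than using the tutorial's ``MultiheadAttention`` layer, and every name and size in it is illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGPTBlock(nn.Module):
    """Causal self-attention + FFN, each wrapped in a skip connection (pre-norm)."""
    def __init__(self, dim, n_heads, ffn_mult=4):
        super().__init__()
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)       # packed q/k/v projection
        self.out_proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_mult * dim),
            nn.GELU(),
            nn.Linear(ffn_mult * dim, dim),
        )

    def forward(self, x):                        # x: (B, S, dim), dense only
        q, k, v = self.qkv(self.norm1(x)).chunk(3, dim=-1)
        head_dim = q.size(-1) // self.n_heads
        q, k, v = (t.unflatten(-1, [self.n_heads, head_dim]).transpose(1, 2)
                   for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).flatten(-2)  # back to (B, S, dim)
        x = x + self.out_proj(attn)              # skip connection around attention
        x = x + self.ffn(self.norm2(x))          # skip connection around the FFN
        return x

block = ToyGPTBlock(dim=64, n_heads=4)           # torch.compile(block) to fuse ops
print(block(torch.randn(2, 10, 64)).shape)       # torch.Size([2, 10, 64])

The tutorial's own layer additionally accepts nested tensor inputs, which this sketch omits.
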
@@ -418,7 +429,7 @@ def benchmark(func, *args, **kwargs):
 # So far, we have demonstrated how to implement a performant ``MultiheadAttention``
 # layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
 # classification of modifications to the transformer architecture, recall that we
-# classified the modifications into layer type, layer ordering and modifications
+# classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
 # (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
 #
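
As an example of how small such a layer-type swap is (``nn.RMSNorm`` ships with recent PyTorch releases; the sizes below are arbitrary):

import torch
import torch.nn as nn

dim = 64
norm = nn.RMSNorm(dim)                        # drop-in swap for nn.LayerNorm(dim)
print(norm(torch.randn(2, 10, dim)).shape)    # torch.Size([2, 10, 64])
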
@@ -570,20 +581,23 @@ def forward(self, x):
 print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
 out = new_mha_layer(query, key, value, is_causal=False)

-# TODO: anything else I can add here?

 ################################################################################
 # Fully masked rows no longer cause NaNs
 # --------------------------------------
 #
 # There has been a long standing issue with ``nn.MultiheadAttention`` and
-# ``scaled_dot_product_attention`` where if a row was fully masked, the output
+# ``scaled_dot_product_attention`` where if a row was fully masked out, the output
 # of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
-# This is because the softmax operation would divide by zero.
+# This is because the softmax over an empty set is undefined.
 #
 # Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
-# this is no longer the case. Instead, fully masked rows will be set to zero. More
-# motivation can be found in the PR description.
+# this is no longer the case. Instead, the output for fully masked rows in ``scaled_dot_product_attention`` is set to zero.
+# For cases where ``nn.MHA`` does not employ the "fast-path", this will also apply.
+#
+# Using a custom MHA layer with NJTs is strongly recommended over the
+# existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness
+# appropriately makes it possible to distinguish when there is an empty sequence.


 ################################################################################
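
A quick way to observe the behavior described in this hunk, on a build that includes the linked PR (older releases return NaN for the fully masked row; the shapes below are arbitrary):

import torch
import torch.nn.functional as F

B, H, S, D = 1, 1, 4, 8
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Boolean attn_mask where True means "may attend"; the last query row is
# allowed to attend to nothing at all.
attn_mask = torch.ones(S, S, dtype=torch.bool)
attn_mask[-1, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 0, -1])   # zeros on builds with the fix; NaN on older releases
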
