"""
- [Title TBD] Unbundling nn.Transformer modules for gains and profits
- ===================================================================
+ Dismantling the ``nn.Transformer`` modules for gains and profits
+ ======================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

The ``torch.nn`` module currently provides various ``Transformer``-related layers.
were made to try to make these layers more flexible.

While historically these layers were intended to provide out-of-the-box, performant
- solutions. We make the observations that
+ solutions, we make the following observations:

1. People want to add slight customizations to their transformer layers
2. Writing these layers and customizations is not hard

own performant transformer layers following our recommended best practices.
The technologies used will be the following

- 1. Nested Tensors with the ``torch.jagged`` layout
+ 1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
2. ``scaled_dot_product_attention``
3. ``torch.compile()``
4. ``FlexAttention``

If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,
- with some examples being HuggingFace transformers and torchtune. Please head
- there instead!
+ with some examples being:
+
+ * `HuggingFace transformers <https://github.com/huggingface/transformers>`_
+ * `xformers <https://github.com/facebookresearch/xformers>`_
+ * `torchtune <https://github.com/pytorch/torchtune>`_
+
+ Please head there instead!

If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that

* `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_

Nested tensors generalize the shape of regular dense tensors, allowing for
- representation of ragged-sized data. In the context of transformers,
- we can think of nested tensors as a tool for representing variable sequence
- lengths. They eliminate the need for the bug-prone practices of explicit
+ representation of ragged-sized data with the same tensor UX. In the context of
+ transformers, we can think of nested tensors as a tool for representing variable
+ sequence lengths. They eliminate the need for the bug-prone practices of explicit
padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
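
To make this concrete, here is a small illustrative sketch (not part of the
tutorial's code) of constructing a nested tensor with the ``torch.jagged``
layout from sequences of different lengths::

    import torch

    # two sequences of different lengths, each with embedding dimension 8
    seqs = [torch.randn(3, 8), torch.randn(5, 8)]
    njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
    print(njt.is_nested)  # True; the ragged dimension is stored unpadded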

* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

``scaled_dot_product_attention`` is a primitive for
$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
implementations of the operator or a fallback implementation. It works out of
- the box in eager mode and also integrates seamlessly with compile.
- As of 2.6, it will also offer grouped query attention natively.
+ the box in eager mode (i.e. the default mode of using PyTorch where operations
+ are executed on the fly as they are encountered) and also integrates seamlessly
+ with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
+ natively.
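
As a quick illustrative sketch (not part of the tutorial's code), the primitive
can be called directly on query, key and value tensors::

    import torch
    import torch.nn.functional as F

    # (batch, num_heads, seq_len, head_dim)
    q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
    # causal masking is handled internally when is_causal=True
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)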

* `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

``torch.compile()`` is a compiler introduced in version 2.0 that is able to
- fuse together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+ capture a graph of PyTorch code and perform various optimizations on it, such as
+ fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
and ``scaled_dot_product_attention`` work seamlessly with compile. In the
context of transformers, the value add of using compile with nested tensor
- and sdpa is that compile can remove framework overhead ones sees in eager mode
+ and SDPA is that compile can remove the framework overhead one sees in eager mode
and fuse sequences of ops in transformers together (e.g. projection and
activation).
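
For example (an illustrative sketch only), compiling a small feed-forward block
is a one-liner, and the projection and activation can then be fused::

    import torch
    import torch.nn as nn

    ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    compiled_ffn = torch.compile(ffn)
    out = compiled_ffn(torch.randn(32, 64))  # first call triggers compilation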

* `FlexAttention <https://pytorch.org/blog/flexattention/>`_

``FlexAttention`` is a primitive that allows users to modify attention scores
prior to the softmax operation. It generalizes the additive ``B`` term above
- for `scaled_dot_product_attention` into allowing you to do any op . It requires
- compile to achieve good performance.
+ for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
+ requires compile to achieve good performance.
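
As a brief sketch (assuming a PyTorch version where ``flex_attention`` is
available under ``torch.nn.attention.flex_attention``), a score modification is
just a Python function::

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def rel_bias(score, batch, head, q_idx, kv_idx):
        # toy relative-position bias; real ALiBi would use per-head slopes
        return score - (q_idx - kv_idx)

    # (batch, num_heads, seq_len, head_dim)
    q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
    out = flex_attention(q, k, v, score_mod=rel_bias)
    # wrap with torch.compile(flex_attention) to get good performance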

The above building blocks are "All You Need" (as of October 2024)
==================================================================

- The main premise in this section is that most transformers these days are
+ The main premise in this section is that most transformer variations are
GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
Blocks and Feed Forward networks. If we were to try to classify the differences
- in this space we might land on something like
+ in this space, we might land on something like:

- 1. Layer type (activation functions e.g. SwiGLU, normalization functions
- e.g. RMSNorm etc., positional encodings e.g. Sinusoidal, Rotary etc.)
+ 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
+ e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
- 3. Modifications to attention score (ALiBi, Relative Positional Bias etc.)
+ 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)

In a pre-compiler world, one might write their custom transformer and observe
that it works but is slow. Then, one might write a custom fused kernel for
- series of ops. In a compiler world, one can do the former, compile and profit.
+ the specific series of ops. In a compiler world, one can do the former, compile
+ and profit.

"""

# intermediate activations will use less memory.
#
# * Performance
- # Since unnecessary computation on padding is skipped, performance improves.
+ # Since padding is not materialized and unnecessary computation on padding is
+ # skipped, performance and memory usage improve.
#
# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
@@ -403,10 +413,11 @@ def benchmark(func, *args, **kwargs):
##################################################################################
# GPT-style layer
# ---------------
- # A basic GPT-style transformer layer consistst of a causal self-attention layer
+ # A basic GPT-style transformer layer consists of a causal self-attention layer
# followed by a feed-forward network (FFN) with skip connections. Implementing
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
- # is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
+ # gives equivalent results to an ``nn.TransformerEncoderLayer`` with
+ # ``is_causal=True``.
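#
# For reference, here is a minimal sketch of such a layer (an illustration, not
# the tutorial's implementation). It assumes an attention module whose forward
# accepts ``(query, key, value, is_causal=...)``, like the ``MultiheadAttention``
# layer above.

import torch
import torch.nn as nn


class GPTLayerSketch(nn.Module):
    def __init__(self, attn: nn.Module, d_model: int, dim_feedforward: int):
        super().__init__()
        self.attn = attn  # e.g. the MultiheadAttention layer defined above
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x):
        # causal self-attention with a skip connection, then norm (post-norm,
        # matching nn.TransformerEncoderLayer's default)
        x = self.norm1(x + self.attn(x, x, x, is_causal=True))
        # feed-forward network with a skip connection, then norm
        x = self.norm2(x + self.ff(x))
        return x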

# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
@@ -418,7 +429,7 @@ def benchmark(func, *args, **kwargs):
# So far, we have demonstrated how to implement a performant ``MultiheadAttention``
# layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
# classification of modifications to the transformer architecture, recall that we
- # classified the modifications into layer type, layer ordering and modifications
+ # classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
#
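# As a tiny illustration (not from the original text), swapping the normalization
# layer type is usually a one-line change, since ``nn.RMSNorm`` follows the same
# calling convention as ``nn.LayerNorm``:

import torch
import torch.nn as nn

d_model = 512
x = torch.randn(2, 16, d_model)
print(nn.LayerNorm(d_model)(x).shape)  # torch.Size([2, 16, 512])
print(nn.RMSNorm(d_model)(x).shape)    # same shape, different normalization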
@@ -570,20 +581,23 @@ def forward(self, x):
print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
out = new_mha_layer(query, key, value, is_causal=False)

- # TODO: anything else I can add here?

################################################################################
# Fully masked rows no longer cause NaNs
# --------------------------------------
#
# There has been a long standing issue with ``nn.MultiheadAttention`` and
- # ``scaled_dot_product_attention`` where if a row was fully masked, the output
+ # ``scaled_dot_product_attention`` where if a row was fully masked out, the output
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
- # This is because the softmax operation would divide by zero .
+ # This is because the softmax over an empty set is undefined.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
- # this is no longer the case. Instead, fully masked rows will be set to zero. More
- # motivation can be found in the PR description.
+ # this is no longer the case. Instead, the output for fully masked rows in
+ # ``scaled_dot_product_attention`` will be set to zero.
+ # For cases where ``nn.MHA`` does not employ the "fast-path", this will also apply.
+ #
+ # Using a custom MHA layer with NJTs is strongly recommended over the
+ # existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness
+ # appropriately makes it possible to distinguish when there is an empty sequence.
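#
# A quick sketch of the behavior described above (assuming a PyTorch build that
# includes the fix): a fully masked query row now yields zeros rather than NaNs.

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 1, 4, 8) for _ in range(3))
attn_mask = torch.ones(4, 4, dtype=torch.bool)
attn_mask[0, :] = False  # the first query position attends to nothing
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.isnan().any())  # expected: False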
################################################################################