If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
-
If you are wondering about what building blocks the ``torch`` library provides
for writing your own transformer layers and best practices, you are in the
right place, please keep reading!


- Introducing the Building Blocks
- ===============================
- First, we will briefly introduce the 4 technologies mentioned in the introduction
+ """
+
+ ################################################################################
+ # Introducing the Building Blocks
+ # ===============================
+ # First, we will briefly introduce the 4 technologies mentioned in the introduction

- * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+ # * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_

- Nested tensors generalize the shape of regular dense tensors, allowing for
- representation of ragged-sized data with the same tensor UX. In the context of
- transformers, we can think of nested tensors as a tool for representing variable
- sequence lengths. They eliminate the need for the bug-prone practices of explicit
- padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
+ # Nested tensors generalize the shape of regular dense tensors, allowing for
+ # representation of ragged-sized data with the same tensor UX. In the context of
+ # transformers, we can think of nested tensors as a tool for representing variable
+ # sequence lengths. They eliminate the need for the bug-prone practices of explicit
+ # padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
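
As a minimal sketch of the jagged layout (assuming a recent PyTorch release where ``torch.jagged`` is available; the variable names are illustrative):

.. code::

    import torch

    # Two sequences of different lengths, same embedding dimension
    seqs = [torch.randn(5, 16), torch.randn(3, 16)]

    # Jagged-layout nested tensor: no padding, no key_padding_mask needed
    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)

    print(nt.is_nested)  # True
    print(nt.shape)      # e.g. torch.Size([2, j1, 16]), with a ragged sequence dim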

- * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
+ # * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

- ``scaled_dot_product_attention`` is a primitive for
- :math:`\t ext{softmax}(\f rac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
- implementations of the operator or a fallback implementation. It works out of
- the box in eager mode (i.e. the default mode of using PyTorch where operations
- are executed on the fly as they are encountered) and also integrates seamlessly
- with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
- natively.
+ # ``scaled_dot_product_attention`` is a primitive for
+ # :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
+ # implementations of the operator or a fallback implementation. It works out of
+ # the box in eager mode (i.e. the default mode of using PyTorch where operations
+ # are executed on the fly as they are encountered) and also integrates seamlessly
+ # with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
+ # natively.
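
A minimal sketch of calling the primitive directly (shapes are ``(batch, num_heads, seq_len, head_dim)``; the sizes here are illustrative):

.. code::

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    # Dispatches to a fused kernel (e.g. Flash Attention) when one is available
    # for the device/dtype, otherwise falls back to a math implementation.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 128, 64])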

- * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
+ # * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

- ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
- capture a graph of PyTorch code and perform various optimizations on it, such as
- fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
- and ``scaled_dot_product_attention`` work seamlessly with compile. In the
- context of transformers, the value add of using compile with nested tensor
- and SDPA is that compile can remove framework overhead ones sees in eager mode
- and fuse sequences of ops in transformers together (e.g. projection and
- activation).
+ # ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
+ # capture a graph of PyTorch code and perform various optimizations on it, such as
+ # fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+ # and ``scaled_dot_product_attention`` work seamlessly with compile. In the
+ # context of transformers, the value add of using compile with nested tensors
+ # and SDPA is that compile can remove the framework overhead one sees in eager mode
+ # and fuse sequences of ops in transformers together (e.g. projection and
+ # activation).
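
As an illustration (the module and names below are a toy sketch, not from the tutorial), compiling a small block removes per-op Python overhead and gives the compiler a chance to fuse the projection's epilogue with the activation:

.. code::

    import torch
    import torch.nn as nn


    class TinyBlock(nn.Module):
        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            # projection followed by an activation -- a candidate for fusion
            return torch.relu(self.proj(x))


    model = torch.compile(TinyBlock())
    out = model(torch.randn(32, 64))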

- * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
+ # * `FlexAttention <https://pytorch.org/blog/flexattention/>`_

- ``FlexAttention`` is a primitive that allows users to modify attention scores
- prior to the softmax operation. It generalizes the additive ``B`` term above
- for `scaled_dot_product_attention`, allowing for arbitrary calculation. It
- requires compile to achieve good performance.
+ # ``FlexAttention`` is a primitive that allows users to modify attention scores
+ # prior to the softmax operation. It generalizes the additive ``B`` term above
+ # for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
+ # requires compile to achieve good performance.
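
A minimal sketch of a ``score_mod`` (assuming a PyTorch release where FlexAttention is available for your device; the ALiBi-style slope and sizes are illustrative, and in practice this would be combined with a causal mask):

.. code::

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def alibi_like_bias(score, b, h, q_idx, kv_idx):
        # Penalize attention to keys that are far from the query position.
        return score - 0.125 * (h + 1) * (q_idx - kv_idx)

    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    # Compiling flex_attention is what makes the score_mod performant.
    compiled_flex = torch.compile(flex_attention)
    out = compiled_flex(q, k, v, score_mod=alibi_like_bias)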

- The above building blocks are "All You Need" (as of October 2024)
- ==================================================================
+ # The above building blocks are "All You Need" (as of October 2024)
+ # ==================================================================

- The main premise in this section is that most transformer variations are
- GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
- Blocks and Feed Forward networks. If we were to try to classify the differences
- in this space, we might land on something like:
+ # The main premise in this section is that most transformer variations are
+ # GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
+ # Blocks and Feed Forward networks. If we were to try to classify the differences
+ # in this space, we might land on something like:

- 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
- e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
- 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
- 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
+ # 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
+ # e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
+ # 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
+ # 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
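
For the first category, a toy sketch of a layer-type swap (assuming a PyTorch release that provides ``nn.RMSNorm``; the ``SwiGLUFeedForward`` module below is illustrative, not part of the tutorial):

.. code::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwiGLUFeedForward(nn.Module):
        def __init__(self, dim: int, hidden: int):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden, bias=False)
            self.w2 = nn.Linear(dim, hidden, bias=False)
            self.w3 = nn.Linear(hidden, dim, bias=False)

        def forward(self, x):
            # SwiGLU: gate the up-projection with a SiLU-activated branch
            return self.w3(F.silu(self.w1(x)) * self.w2(x))

    block = nn.Sequential(nn.RMSNorm(64), SwiGLUFeedForward(64, 256))
    out = block(torch.randn(2, 10, 64))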


- In a pre-compiler world, one might write their custom transformer and observe
- that it works but is slow. Then, one might write a custom fused kernel for
- the specific series of ops. In a compiler world, one can do the former, compile
- and profit.
+ # In a pre-compiler world, one might write their custom transformer and observe
+ # that it works but is slow. Then, one might write a custom fused kernel for
+ # the specific series of ops. In a compiler world, one can do the former, compile
+ # and profit.

- """

###############################################################################
# MultiheadAttention
@@ -399,13 +400,12 @@ def benchmark(func, *args, **kwargs):
######################################################################################
# For reference some sample outputs on A100:
#
- # ```
- # padded_time=0.03454, padded_peak_memory=4.14 GB
- # nested_time=0.00612, nested_peak_memory=0.76 GB
- # Difference between vanilla and nested result 0.0
- # Nested speedup: 5.65
- # Nested peak memory reduction 3.39 GB
- # ````
+ # .. code::
+ #
+ #    padded_time=0.03454, padded_peak_memory=4.14 GB
+ #    nested_time=0.00612, nested_peak_memory=0.76 GB
+ #    Difference between vanilla and nested result 0.0
+ #    Nested speedup: 5.65
+ #    Nested peak memory reduction 3.39 GB
#
# We can also see the same for backward pass

@@ -429,16 +429,16 @@ def benchmark(func, *args, **kwargs):
##################################################################################
# Sample outputs on A100:
#
- # ```
- # padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
- # nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
- # Nested backward speedup: 144.13
- # Nested backward peak memory reduction 1.86 GB
- # Difference in out_proj.weight.grad 0.000244140625
- # Difference in packed_proj.weight.grad 0.001556396484375
- # Difference in out_proj.bias.grad 0.0
- # Difference in packed_proj.bias.grad 0.001953125
- # ```
+ # .. code::
+ #
+ #    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+ #    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
+ #    Nested backward speedup: 144.13
+ #    Nested backward peak memory reduction 1.86 GB
+ #    Difference in out_proj.weight.grad 0.000244140625
+ #    Difference in packed_proj.weight.grad 0.001556396484375
+ #    Difference in out_proj.bias.grad 0.0
+ #    Difference in packed_proj.bias.grad 0.001953125
+ #

##################################################################################
# GPT-style layer
@@ -462,13 +462,13 @@ def benchmark(func, *args, **kwargs):
# classification of modifications to the transformer architecture, recall that we
# classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
- # (e.g. swapping``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
+ # (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
#
# In this section, we will discuss various functionalities using the
# aforementioned building blocks. In particular,
#
# * Cross Attention
- # * Fully masked rows no longer cause ``NaN``s
+ # * Fully masked rows no longer cause NaNs
# * Modifying attention score: ALiBi with FlexAttention and NJT
# * Packed Projection