# Introducing the Building Blocks
# ===============================
# First, we will briefly introduce the 4 technologies mentioned in the introduction
-
+ #
# * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
-
+ #
# Nested tensors generalize the shape of regular dense tensors, allowing for
# representation of ragged-sized data with the same tensor UX. In the context of
# transformers, we can think of nested tensors as a tool for representing variable
# sequence lengths. They eliminate the need for the bug-prone practices of explicit
# padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
-
+ #
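# As a quick, illustrative sketch (the shapes below are arbitrary and not part of
# the tutorial's model), variable-length sequences can be packed into a single
# jagged nested tensor without any padding:
#
# .. code::
#
#    import torch
#
#    # Three sequences of different lengths, each with embedding dimension 8.
#    seqs = [torch.randn(length, 8) for length in (3, 5, 2)]
#
#    # Pack them into one nested tensor with the ragged (jagged) layout.
#    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
#
#    print(nt.is_nested)  # True
#    print(nt.shape)      # ragged dim is reported symbolically, e.g. (3, j1, 8)
#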
# * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
-
+ #
# ``scaled_dot_product_attention`` is a primitive for
# :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
# implementations of the operator or a fallback implementation. It works out of
# the box in eager mode (i.e. the default mode of using PyTorch where operations
# are executed on the fly as they are encountered) and also integrates seamlessly
# with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
# natively.
-
+ #
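# As a minimal sketch (shapes here are illustrative only), the primitive can be
# called directly on dense ``(batch, num_heads, seq_len, head_dim)`` inputs:
#
# .. code::
#
#    import torch
#    import torch.nn.functional as F
#
#    batch, num_heads, seq_len, head_dim = 2, 4, 16, 32
#    q, k, v = (torch.randn(batch, num_heads, seq_len, head_dim) for _ in range(3))
#
#    # Dispatches to a fused kernel when one is available for the given
#    # device/dtype/shape, otherwise falls back to a math implementation.
#    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#    print(out.shape)  # torch.Size([2, 4, 16, 32])
#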
# * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
-
+ #
# ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
# capture a graph of PyTorch code and perform various optimizations on it, such as
# fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
# and ``scaled_dot_product_attention`` work seamlessly with compile. In the context
# of transformers, the value of using compile with nested tensors and SDPA is that
# compile can remove the framework overhead one sees in eager mode and fuse
# sequences of ops in transformers together (e.g. projection and activation).
-
+ #
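# As a hedged sketch of the workflow (the tiny module below is a stand-in, not one
# of the tutorial's layers), one simply wraps a function or module in
# ``torch.compile`` and calls it as usual:
#
# .. code::
#
#    import torch
#
#    # A projection followed by an activation: the kind of op sequence compile can fuse.
#    proj = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
#    compiled_proj = torch.compile(proj)
#
#    x = torch.randn(8, 64)
#    y = compiled_proj(x)  # the first call triggers graph capture and compilation
#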
# * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
-
+ #
# ``FlexAttention`` is a primitive that allows users to modify attention scores
# prior to the softmax operation. It generalizes the additive ``B`` term in the
# ``scaled_dot_product_attention`` formula above, allowing for arbitrary
# user-defined score modifications. It requires ``torch.compile`` to achieve good
# performance.
-
+ #
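# As a hedged sketch (the relative-position bias below is illustrative, not a
# modification prescribed by the tutorial), a ``score_mod`` is a small function
# applied to each attention score before the softmax:
#
# .. code::
#
#    import torch
#    from torch.nn.attention.flex_attention import flex_attention
#
#    def rel_bias(score, b, h, q_idx, kv_idx):
#        # Penalize attention to distant positions.
#        return score - 0.1 * (q_idx - kv_idx).abs()
#
#    q, k, v = (torch.randn(2, 4, 128, 32) for _ in range(3))
#    flex_attention = torch.compile(flex_attention)  # compile for good performance
#    out = flex_attention(q, k, v, score_mod=rel_bias)
#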
# The above building blocks are "All You Need" (as of October 2024)
# ==================================================================
-
+ #
# The main premise in this section is that most transformer variations are
# GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
# Blocks and Feed Forward networks. If we were to try to classify the differences
# in this space, we might land on something like:
-
+ #
# 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
#    e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
# 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
-
-
+ #
+ #
# In a pre-compiler world, one might write their custom transformer and observe
# that it works but is slow. Then, one might write a custom fused kernel for
# the specific series of ops. In a compiler world, one can do the former, compile
@@ -400,12 +400,11 @@ def benchmark(func, *args, **kwargs):
######################################################################################
# For reference some sample outputs on A100:
#
- # ..code::
- # padded_time=0.03454, padded_peak_memory=4.14 GB
- # nested_time=0.00612, nested_peak_memory=0.76 GB
- # Difference between vanilla and nested result 0.0
- # Nested speedup: 5.65
- # Nested peak memory reduction 3.39 GB
+ # padded_time=0.03454, padded_peak_memory=4.14 GB
+ # nested_time=0.00612, nested_peak_memory=0.76 GB
+ # Difference between vanilla and nested result 0.0
+ # Nested speedup: 5.65
+ # Nested peak memory reduction 3.39 GB
#
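# The timings and peak-memory figures above come from the ``benchmark`` helper
# referenced in the hunk header. As a hedged sketch only (assuming a CUDA device;
# the tutorial's actual helper may differ), such a function can be written as:
#
# .. code::
#
#    import timeit
#    import torch
#
#    def benchmark_sketch(func, *args, **kwargs):
#        # Measure wall-clock time and peak allocated CUDA memory for one call.
#        torch.cuda.synchronize()
#        torch.cuda.reset_peak_memory_stats()
#        begin = timeit.default_timer()
#        output = func(*args, **kwargs)
#        torch.cuda.synchronize()
#        end = timeit.default_timer()
#        return output, end - begin, torch.cuda.max_memory_allocated()
#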
# We can also see the same for backward pass
@@ -429,15 +428,14 @@ def benchmark(func, *args, **kwargs):
##################################################################################
# Sample outputs on A100:
#
- # ..code::
- # ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
- # ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
- # Nested backward speedup: 144.13
- # Nested backward peak memory reduction 1.86 GB
- # Difference in ``out_proj.weight.grad`` 0.000244140625
- # Difference in ``packed_proj.weight.grad`` 0.001556396484375
- # Difference in ``out_proj.bias.grad`` 0.0
- # Difference in ``packed_proj.bias.grad`` 0.001953125
+ # ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
+ # ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
+ # Nested backward speedup: 144.13
+ # Nested backward peak memory reduction 1.86 GB
+ # Difference in ``out_proj.weight.grad`` 0.000244140625
+ # Difference in ``packed_proj.weight.grad`` 0.001556396484375
+ # Difference in ``out_proj.bias.grad`` 0.0
+ # Difference in ``packed_proj.bias.grad`` 0.001953125
#
##################################################################################