
Commit bb3106a

rendering again
1 parent e856813 commit bb3106a

1 file changed: +26 / -28 lines

intermediate_source/transformer_building_blocks.py (26 additions, 28 deletions)
@@ -57,27 +57,27 @@
 # Introducing the Building Blocks
 # ===============================
 # First, we will briefly introduce the 4 technologies mentioned in the introduction
-
+#
 # * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
-
+#
 # Nested tensors generalize the shape of regular dense tensors, allowing for
 # representation of ragged-sized data with the same tensor UX. In the context of
 # transformers, we can think of nested tensors as a tool for representing variable
 # sequence lengths. They eliminate the need for the bug-prone practices of explicit
 # padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
-
+#
 # * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
-
+#
 # ``scaled_dot_product_attention`` is a primitive for
 # :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
 # implementations of the operator or a fallback implementation. It works out of
 # the box in eager mode (i.e. the default mode of using PyTorch where operations
 # are executed on the fly as they are encountered) and also integrates seamlessly
 # with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
 # natively.
-
+#
 # * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
-
+#
 # ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
 # capture a graph of PyTorch code and perform various optimizations on it, such as
 # fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
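
For context, here is a minimal sketch (not part of this commit) of the first two building blocks working together: two variable-length sequences stored as a nested tensor with the ``torch.jagged`` layout and passed to ``scaled_dot_product_attention`` with no explicit padding or masking. The shapes and sizes are illustrative assumptions.

# Illustrative sketch only; not taken from the tutorial being diffed.
import torch
import torch.nn.functional as F

# Two sequences of different lengths, each (seq_len, num_heads, head_dim).
seqs = [torch.randn(5, 8, 16), torch.randn(9, 8, 16)]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)

# SDPA expects (batch, num_heads, seq_len, head_dim), so move the heads dim
# in front of the ragged sequence dim.
q = k = v = nt.transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v)
print(out.is_nested)  # True: the output keeps the per-sequence lengths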
@@ -86,28 +86,28 @@
 # and SDPA is that compile can remove framework overhead one sees in eager mode
 # and fuse sequences of ops in transformers together (e.g. projection and
 # activation).
-
+#
 # * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
-
+#
 # ``FlexAttention`` is a primitive that allows users to modify attention scores
 # prior to the softmax operation. It generalizes the additive ``B`` term above
 # for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
 # requires compile to achieve good performance.
-
+#
 # The above building blocks are "All You Need" (as of October 2024)
 # ==================================================================
-
+#
 # The main premise in this section is that most transformer variations are
 # GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
 # Blocks and Feed Forward networks. If we were to try to classify the differences
 # in this space, we might land on something like:
-
+#
 # 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
 # e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
 # 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
 # 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
-
-
+#
+#
 # In a pre-compiler world, one might write their custom transformer and observe
 # that it works but is slow. Then, one might write a custom fused kernel for
 # the specific series of ops. In a compiler world, one can do the former, compile
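
As a rough illustration of the other two building blocks (again, not part of this commit), the sketch below defines a relative-position ``score_mod`` for FlexAttention and wraps it with ``torch.compile``, which FlexAttention needs to reach good performance. The bias is a simplified, ALiBi-flavored assumption, and a CUDA device is assumed, matching the A100 numbers quoted later.

# Illustrative sketch only; requires PyTorch 2.5+ and a GPU.
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias_score_mod(score, b, h, q_idx, kv_idx):
    # Modify each attention score before softmax, generalizing the additive
    # ``B`` term of ``scaled_dot_product_attention``.
    return score - (q_idx - kv_idx).abs()

compiled_flex = torch.compile(flex_attention)

B, H, S, D = 2, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")
out = compiled_flex(q, k, v, score_mod=rel_bias_score_mod)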
@@ -400,12 +400,11 @@ def benchmark(func, *args, **kwargs):
 ######################################################################################
 # For reference some sample outputs on A100:
 #
-# ..code::
-#    padded_time=0.03454, padded_peak_memory=4.14 GB
-#    nested_time=0.00612, nested_peak_memory=0.76 GB
-#    Difference between vanilla and nested result 0.0
-#    Nested speedup: 5.65
-#    Nested peak memory reduction 3.39 GB
+# padded_time=0.03454, padded_peak_memory=4.14 GB
+# nested_time=0.00612, nested_peak_memory=0.76 GB
+# Difference between vanilla and nested result 0.0
+# Nested speedup: 5.65
+# Nested peak memory reduction 3.39 GB
 #
 # We can also see the same for backward pass
 
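The ``benchmark(func, *args, **kwargs)`` helper named in the hunk headers is defined elsewhere in the file and is not shown in this diff; a plausible sketch of such a helper (purely an assumption about its shape) that would produce time and peak-memory numbers like those above looks like this:

# Hypothetical sketch; the tutorial's actual benchmark helper is not in this diff.
import timeit
import torch

def benchmark_sketch(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    # Wall-clock seconds and peak GPU memory in GB, in the style of the
    # padded_time / padded_peak_memory numbers quoted above.
    return output, end - begin, torch.cuda.max_memory_allocated() / 1e9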

@@ -429,15 +428,14 @@ def benchmark(func, *args, **kwargs):
 ##################################################################################
 # Sample outputs on A100:
 #
-# ..code::
-#    ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
-#    ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
-#    Nested backward speedup: 144.13
-#    Nested backward peak memory reduction 1.86 GB
-#    Difference in ``out_proj.weight.grad`` 0.000244140625
-#    Difference in ``packed_proj.weight.grad`` 0.001556396484375
-#    Difference in ``out_proj.bias.grad`` 0.0
-#    Difference in ``packed_proj.bias.grad`` 0.001953125
+# ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
+# ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
+# Nested backward speedup: 144.13
+# Nested backward peak memory reduction 1.86 GB
+# Difference in ``out_proj.weight.grad`` 0.000244140625
+# Difference in ``packed_proj.weight.grad`` 0.001556396484375
+# Difference in ``out_proj.bias.grad`` 0.0
+# Difference in ``packed_proj.bias.grad`` 0.001953125
 #
 
 ##################################################################################
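
The gradient differences quoted above compare the padded and nested paths parameter by parameter; a hedged sketch of how such a number could be computed is below (``mha_padded`` and ``mha_nested`` are hypothetical module instances exposing the ``out_proj`` / ``packed_proj`` submodules named in the output).

# Hypothetical sketch; the module names are illustrative, not from the tutorial.
def grad_max_diff(param_a, param_b):
    # Largest absolute elementwise difference between two gradients.
    return (param_a.grad - param_b.grad).abs().max().item()

# print("Difference in out_proj.weight.grad",
#       grad_max_diff(mha_padded.out_proj.weight, mha_nested.out_proj.weight))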
