Commit ee5482c

rendering
1 parent 311751d commit ee5482c

1 file changed (+66 -66 lines changed)

intermediate_source/transformer_building_blocks.py

Lines changed: 66 additions & 66 deletions
@@ -46,72 +46,73 @@
 If you are only interested in performant attention score modifications, please
 head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
-
 If you are wondering about what building blocks the ``torch`` library provides
 for writing your own transformer layers and best practices, you are in the
 right place, please keep reading!
 
 
-Introducing the Building Blocks
-===============================
-First, we will briefly introduce the 4 technologies mentioned in the introduction
+"""
+
+################################################################################
+# Introducing the Building Blocks
+# ===============================
+# First, we will briefly introduce the 4 technologies mentioned in the introduction
 
-* `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+# * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
 
-Nested tensors generalize the shape of regular dense tensors, allowing for
-representation of ragged-sized data with the same tensor UX. In the context of
-transformers, we can think of nested tensors as a tool for representing variable
-sequence lengths. They eliminate the need for the bug-prone practices of explicit
-padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
+# Nested tensors generalize the shape of regular dense tensors, allowing for
+# representation of ragged-sized data with the same tensor UX. In the context of
+# transformers, we can think of nested tensors as a tool for representing variable
+# sequence lengths. They eliminate the need for the bug-prone practices of explicit
+# padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
 
-* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
+# * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
 
-``scaled_dot_product_attention`` is a primitive for
-:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
-implementations of the operator or a fallback implementation. It works out of
-the box in eager mode (i.e. the default mode of using PyTorch where operations
-are executed on the fly as they are encountered) and also integrates seamlessly
-with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
-natively.
+# ``scaled_dot_product_attention`` is a primitive for
+# :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
+# implementations of the operator or a fallback implementation. It works out of
+# the box in eager mode (i.e. the default mode of using PyTorch where operations
+# are executed on the fly as they are encountered) and also integrates seamlessly
+# with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
+# natively.
 
-* `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
+# * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 
-``torch.compile()`` is a compiler introduced in version 2.0 that is able to
-capture a graph of PyTorch code and perform various optimizations on it, such as
-fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
-and ``scaled_dot_product_attention`` work seamlessly with compile. In the
-context of transformers, the value add of using compile with nested tensor
-and SDPA is that compile can remove framework overhead ones sees in eager mode
-and fuse sequences of ops in transformers together (e.g. projection and
-activation).
+# ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
+# capture a graph of PyTorch code and perform various optimizations on it, such as
+# fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+# and ``scaled_dot_product_attention`` work seamlessly with compile. In the
+# context of transformers, the value add of using compile with nested tensor
+# and SDPA is that compile can remove framework overhead ones sees in eager mode
+# and fuse sequences of ops in transformers together (e.g. projection and
+# activation).
 
-* `FlexAttention <https://pytorch.org/blog/flexattention/>`_
+# * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
 
-``FlexAttention`` is a primitive that allows users to modify attention scores
-prior to the softmax operation. It generalizes the additive ``B`` term above
-for `scaled_dot_product_attention`, allowing for arbitrary calculation. It
-requires compile to achieve good performance.
+# ``FlexAttention`` is a primitive that allows users to modify attention scores
+# prior to the softmax operation. It generalizes the additive ``B`` term above
+# for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
+# requires compile to achieve good performance.
 
-The above building blocks are "All You Need" (as of October 2024)
-==================================================================
+# The above building blocks are "All You Need" (as of October 2024)
+# ==================================================================
 
-The main premise in this section is that most transformer variations are
-GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
-Blocks and Feed Forward networks. If we were to try to classify the differences
-in this space, we might land on something like:
+# The main premise in this section is that most transformer variations are
+# GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
+# Blocks and Feed Forward networks. If we were to try to classify the differences
+# in this space, we might land on something like:
 
-1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
-e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
-2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
-3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
+# 1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
+# e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
+# 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
+# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
 
 
-In a pre-compiler world, one might write their custom transformer and observe
-that it works but is slow. Then, one might write a custom fused kernel for
-the specific series of ops. In a compiler world, one can do the former, compile
-and profit.
+# In a pre-compiler world, one might write their custom transformer and observe
+# that it works but is slow. Then, one might write a custom fused kernel for
+# the specific series of ops. In a compiler world, one can do the former, compile
+# and profit.
 
-"""
 
 ###############################################################################
 # MultiheadAttention
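The four building blocks introduced in the hunk above can each be exercised in a few lines. First, a minimal sketch of representing variable-length sequences with nested tensors; this example is illustrative rather than part of the tutorial file, and it assumes a PyTorch version that supports the torch.jagged layout.

    import torch

    # Three sequences of different lengths, each with embedding dimension 8.
    seqs = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]

    # Pack them into a single jagged nested tensor -- no padding tensor or
    # key_padding_mask is needed to track the ragged sequence dimension.
    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
    print(nt.is_nested)  # True
    print(nt.size(0))    # 3 (batch dimension); the sequence dimension is ragged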
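Next, a minimal sketch of calling the scaled_dot_product_attention primitive directly; the shapes (batch, heads, sequence length, head dim) and values are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    # (batch, num_heads, seq_len, head_dim)
    q = torch.randn(2, 4, 16, 32)
    k = torch.randn(2, 4, 16, 32)
    v = torch.randn(2, 4, 16, 32)

    # Computes softmax(q @ k^T / sqrt(E) + B) @ v, dispatching to a fused
    # kernel when one is available for these inputs; is_causal applies a
    # causal mask in place of an explicit bias term.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 16, 32])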
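A minimal sketch of wrapping a small attention block in torch.compile so that framework overhead is removed and sequences of ops can be fused; TinyAttention is a made-up module for illustration, not the layer built later in the tutorial.

    import torch
    import torch.nn.functional as F

    class TinyAttention(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.qkv = torch.nn.Linear(dim, 3 * dim)  # packed q/k/v projection
            self.proj = torch.nn.Linear(dim, dim)

        def forward(self, x):
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            return self.proj(F.scaled_dot_product_attention(q, k, v))

    # Compile once; subsequent calls run the optimized graph.
    model = torch.compile(TinyAttention(64))
    out = model(torch.randn(2, 16, 64))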
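Finally, a minimal sketch of FlexAttention with a score_mod that adds a simple relative-position bias before the softmax; the bias itself is an arbitrary illustrative choice, and flex_attention assumes a recent PyTorch (2.5+) on a supported device, with torch.compile needed for good performance.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def rel_bias(score, b, h, q_idx, kv_idx):
        # Modify the attention score prior to softmax -- here a linear
        # relative-position bias, but any pointwise calculation works.
        return score + 0.01 * (q_idx - kv_idx)

    q = torch.randn(2, 4, 128, 32)
    k = torch.randn(2, 4, 128, 32)
    v = torch.randn(2, 4, 128, 32)

    compiled_flex = torch.compile(flex_attention)
    out = compiled_flex(q, k, v, score_mod=rel_bias)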
@@ -399,13 +400,12 @@ def benchmark(func, *args, **kwargs):
 ######################################################################################
 # For reference some sample outputs on A100:
 #
-# ```
-# padded_time=0.03454, padded_peak_memory=4.14 GB
-# nested_time=0.00612, nested_peak_memory=0.76 GB
-# Difference between vanilla and nested result 0.0
-# Nested speedup: 5.65
-# Nested peak memory reduction 3.39 GB
-# ````
+# ..code::
+#    padded_time=0.03454, padded_peak_memory=4.14 GB
+#    nested_time=0.00612, nested_peak_memory=0.76 GB
+#    Difference between vanilla and nested result 0.0
+#    Nested speedup: 5.65
+#    Nested peak memory reduction 3.39 GB
 #
 # We can also see the same for backward pass
 
@@ -429,16 +429,16 @@ def benchmark(func, *args, **kwargs):
 ##################################################################################
 # Sample outputs on A100:
 #
-# ```
-# padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
-# nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
-# Nested backward speedup: 144.13
-# Nested backward peak memory reduction 1.86 GB
-# Difference in out_proj.weight.grad 0.000244140625
-# Difference in packed_proj.weight.grad 0.001556396484375
-# Difference in out_proj.bias.grad 0.0
-# Difference in packed_proj.bias.grad 0.001953125
-# ```
+# ..code::
+#    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+#    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
+#    Nested backward speedup: 144.13
+#    Nested backward peak memory reduction 1.86 GB
+#    Difference in out_proj.weight.grad 0.000244140625
+#    Difference in packed_proj.weight.grad 0.001556396484375
+#    Difference in out_proj.bias.grad 0.0
+#    Difference in packed_proj.bias.grad 0.001953125
+#
 
 ##################################################################################
 # GPT-style layer
@@ -462,13 +462,13 @@ def benchmark(func, *args, **kwargs):
 # classification of modifications to the transformer architecture, recall that we
 # classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
+# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
 #
 # In this section, we will discuss various functionalities using the
 # aforementioned building blocks. In particular,
 #
 # * Cross Attention
-# * Fully masked rows no longer cause ``NaN``s
+# * Fully masked rows no longer cause NaNs
 # * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
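As a small illustration of the "layer type" swap the hunk above calls straightforward, here is a hedged sketch of switching a block's normalization from LayerNorm to RMSNorm; the make_norm helper is hypothetical, and nn.RMSNorm assumes PyTorch 2.4+.

    import torch.nn as nn

    def make_norm(dim, kind="layernorm"):
        # Swap the normalization layer without touching the rest of the block.
        return nn.RMSNorm(dim) if kind == "rmsnorm" else nn.LayerNorm(dim)

    norm = make_norm(512, kind="rmsnorm")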
474474
