Commit 1fa0cd0

Sdpa tutorial (#2252)
1 parent 9580916 commit 1fa0cd0

File tree

2 files changed: +346 -1 lines changed

index.rst

Lines changed: 9 additions & 1 deletion

@@ -502,7 +502,7 @@ What's new in PyTorch tutorials?
    :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
    :link: intermediate/torchserve_with_ipex_2
    :tags: Model-Optimization,Production
-
+
 .. customcarditem::
    :header: Introduction to nvFuser
    :card_description: An introduction to nvFuser

@@ -524,6 +524,13 @@ What's new in PyTorch tutorials?
    :link: intermediate/torch_compile_tutorial.html
    :tags: Model-Optimization

+.. customcarditem::
+   :header: (beta) Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
+   :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components.
+   :image: _static/img/thumbnails/cropped/pytorch-logo.png
+   :link: intermediate/scaled_dot_product_attention_tutorial.html
+   :tags: Model-Optimization,Attention,Transformer
+
 .. Parallel-and-Distributed-Training

@@ -909,6 +916,7 @@ Additional Resources
    intermediate/nvfuser_intro_tutorial
    intermediate/ax_multiobjective_nas_tutorial
    intermediate/torch_compile_tutorial
+   intermediate/scaled_dot_product_attention_tutorial

 .. toctree::
    :maxdepth: 2

intermediate/scaled_dot_product_attention_tutorial.py

Lines changed: 337 additions & 0 deletions

@@ -0,0 +1,337 @@
"""
Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
================================================================================

"""


######################################################################
# Summary
# ~~~~~~~~
#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
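
# To ground the definition above, here is a minimal, non-fused sketch of the
# same computation written with existing PyTorch ops: softmax(QK^T / sqrt(E)) V.
# It is illustrative only: it omits the masking, dropout, and numerical tricks
# that the fused kernels handle, and the helper name is ours, not part of the API.
import math

def naive_sdpa(query, key, value):
    # Scale by 1/sqrt(E), where E is the embedding size of the query.
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

# The sketch should agree with the fused result up to floating point tolerance.
print(torch.allclose(naive_sdpa(query, key, value),
                     F.scaled_dot_product_attention(query, key, value),
                     atol=1e-5, rtol=1e-4))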


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through the
# available backends while measuring performance.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
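
# As a quick sanity check (outside of the context manager, so these reflect the
# global defaults), recent PyTorch builds expose query helpers for the SDPA
# backends on ``torch.backends.cuda``. If your version does not provide them,
# treat this cell as optional.
print(f"Flash SDP enabled:            {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory-efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled:             {torch.backends.cuda.math_sdp_enabled()}")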


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don’t have a GPU and are running on CPU, then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed, as
#   the check below can help confirm.
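
# A minimal check of the GPU generation, assuming a CUDA device is present.
# (FlashAttention in particular targets newer compute capabilities, so this can
# help explain a fallback; the exact requirements depend on your PyTorch build.)
if device == "cuda":
    major, minor = torch.cuda.get_device_capability()
    print(f"Running on {torch.cuda.get_device_name()} with compute capability {major}.{minor}")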


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        # query_projected packs query, key, and value, so its width is 3 * embed_dimension
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
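
# A quick smoke test of the block's interface (shapes only, no claims about
# numerics): input and output are both (batch, seq_len, embed_dimension).
# The sizes here are arbitrary and assume the CUDA device used above.
smoke_input = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
print(model(smoke_input).shape)  # expected: torch.Size([4, 128, 512])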


######################################################################
# NestedTensor and Dense tensor support
# -------------------------------------
#
# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
#

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support NestedTensor for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
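
# For illustration only: a NestedTensor can be materialized as a padded dense
# tensor with ``torch.nested.to_padded_tensor``. The padded layout is what the
# NestedTensor path lets you avoid spending compute on.
padded = torch.nested.to_padded_tensor(random_nt, padding=0.0)
print(f"Padded dense shape: {padded.shape}")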


######################################################################
# Using SDPA with torch.compile
# =============================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Compilation happens on the first call, so run the compiled model once before timing it
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
#
# The exact execution time depends on your machine; however, the results on
# my machine were:
#
# - The non compiled module runs in 166.616 microseconds
# - The compiled module runs in 166.726 microseconds
#
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module reduced the time per training step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ae3a8d5 of NanoGPT, training on
# the Shakespeare dataset.
#

######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. We also built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# compilable with ``torch.compile``. In the process we have shown how the
# profiling tools can be used to explore the performance characteristics
# of a user-defined module.
#
