
Commit 65cebeb: fixes
1 parent eefd7f0

File tree: 2 files changed, +92 -75 lines

index.rst

Lines changed: 3 additions & 3 deletions
@@ -525,10 +525,10 @@ What's new in PyTorch tutorials?
       :tags: Model-Optimization
 
    .. customcarditem::
-      :header: (beta) An overview of torch.nn.functional.scaled_dot_product_attention
+      :header: (beta) Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
       :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components.
       :image: _static/img/thumbnails/cropped/pytorch-logo.png
-      :link: beginner/scaled_dot_product_attention_tutorial.html
+      :link: intermediate/scaled_dot_product_attention_tutorial.html
       :tags: Model-Optimization,Attention,Transformer
 
 .. Parallel-and-Distributed-Training
@@ -916,7 +916,7 @@ Additional Resources
    intermediate/nvfuser_intro_tutorial
    intermediate/ax_multiobjective_nas_tutorial
    intermediate/torch_compile_tutorial
-   beginner/scaled_dot_product_attention_tutorial
+   intermediate/scaled_dot_product_attention_tutorial
 
 .. toctree::
    :maxdepth: 2

beginner_source/scaled_dot_product_attention_tutorial.py renamed to intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 89 additions & 72 deletions
@@ -1,45 +1,39 @@
 """
-An overview of torch.nn.functional.scaled_dot_product_attention
-===============================================================
+Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
+================================================================================
 
 """
 
 
 ######################################################################
 # Summary
 # ~~~~~~~~
-#
-# In this tutorial we want to highlight a new ``torch.nn.functional`` function
+#
+# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
 # that can be helpful for implementing transformer architectures. The
 # function is named ``torch.nn.functional.scaled_dot_product_attention``.
-# There is some extensive documentation on the function in the `PyTorch
-# documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
-# This function has already been incorporated into torch.nn.MHA
-# (Multi-Head Attention) and ``torch.nn.TransformerEncoderLayer``.
-#
+# For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
+# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
+#
 # Overview
-# ~~~~~~~
-# At a high level this PyTorch function calculates the
-# scaled dot product attention between query, key, and value according to
+# ~~~~~~~~~
+# At a high level, this PyTorch function calculates the
+# scaled dot product attention (SDPA) between query, key, and value according to
 # the definition found in the paper `Attention is all you
-# need <https://arxiv.org/abs/1706.03762>`__. While this function can be
-# written in PyTorch using existing functions, for GPU tensors this
-# function will implicitly dispatch to an optimized implementation. The
-# function is also highly modular and can be used to implement other
-# attention mechanisms such as
-# `Linformer <https://arxiv.org/abs/2006.04768>`__
-#
-# Fused implementations:
+# need <https://arxiv.org/abs/1706.03762>`__. While this function can
+# be written in PyTorch using existing functions, a fused implementation can provide
+# large performance benefits over a naive implementation.
+#
+# Fused implementations
 # ~~~~~~~~~~~~~~~~~~~~~~
-#
-# For CUDA tensor inputs the function will dispatch into one of three
-# implementations:
+#
+# For CUDA tensor inputs, the function will dispatch into one of the following
+# implementations:
+#
 # * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
-#   Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__ \*
-#   `Memory-Efficient
-#   Attention <https://github.com/facebookresearch/xformers>`__ \* A PyTorch
-#   implementation defined in C++
-#
+# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
+# * A PyTorch implementation defined in C++
+#
 
 import torch
 import torch.nn as nn
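
Aside (illustrative, not part of this commit): a minimal sketch of calling the function this hunk's overview describes. The tensor shapes are arbitrary assumptions, not values from the tutorial.

import torch
import torch.nn.functional as F

# Query, key, and value with shape (batch, heads, seq_len, head_dim).
query = torch.rand(2, 8, 64, 32)
key = torch.rand(2, 8, 64, 32)
value = torch.rand(2, 8, 64, 32)

# Computes softmax(Q @ K^T / sqrt(head_dim)) @ V; is_causal=True applies a causal mask.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 64, 32])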
@@ -54,15 +48,15 @@
 ######################################################################
 # Explicit Dispatcher Control
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#
+#
 # While the function will implicitly dispatch to one of the three
 # implementations, the user can also explicitly control the dispatch via
 # the use of a context manager. This context manager allows users to
 # explicitly disable certain implementations. If a user wants to ensure
-# the function is indeed using the fasted implementation for their
-# specific inputs the context manager can be used to sweep through
+# the function is indeed using the fastest implementation for their
+# specific inputs, the context manager can be used to sweep through
 # measuring performance.
-#
+#
 
 # Lets define a helpful benchmarking function:
 import torch.utils.benchmark as benchmark
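
Aside (illustrative, not part of this commit): the hunk above only shows the ``torch.utils.benchmark`` import; the helper it refers to and the ``backend_map`` used in the next hunk are along these lines. The bodies below are an assumption based on the PyTorch 2.0-era ``torch.backends.cuda`` API, not copied from the file.

import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    # Time f(*args, **kwargs) and report the mean runtime in microseconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return t0.blocked_autorange().mean * 1e6

# Keyword arguments for sdp_kernel() that enable exactly one backend at a time.
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}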
@@ -102,35 +96,38 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 
 
 with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
-    print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
-
+    try:
+        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("FlashAttention is not supported. See warnings for reasons.")
 
 with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-    print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    try:
+        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("EfficientAttention is not supported. See warnings for reasons.")
 
 
 ######################################################################
 # Hardware dependence
 # ~~~~~~~~~~~~~~~~~~~
-#
+#
 # Depending on what machine you ran the above cell on and what hardware is
-# available your results might be different.
-# - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all
-# are running on CPU then the context manager will have no effect and all
-# three run should return similar timings. - Depending on what Compute
-# Capability your graphics card supports FlashAttention or memory
-# efficient might have failed.
-#
+# available, your results might be different.
+# - If you don’t have a GPU and are running on CPU then the context manager
+# will have no effect and all three runs should return similar timings.
+# - Depending on what compute capability your graphics card supports
+# flash attention or memory efficient might have failed.
 
 
 ######################################################################
 # Causal Self Attention
 # ~~~~~~~~~~~~~~~~~~~~~
-#
+#
 # Below is an example implementation of a multi-headed causal self
 # attention block inspired by Andrej Karpathy’s
 # `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
-#
+#
 
 class CausalSelfAttention(nn.Module):
 
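
Aside (illustrative, not part of this commit): the body of ``CausalSelfAttention`` is not shown in this diff. A minimal sketch of how such a block can be built on ``scaled_dot_product_attention`` might look as follows; the class and dimension names here are hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # fused q, k, v projection
        self.proj = nn.Linear(embed_dim, embed_dim)     # output projection

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim), the layout SDPA expects.
        q, k, v = (t.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

# Example: TinyCausalSelfAttention(128, 8)(torch.rand(4, 16, 128)).shape == (4, 16, 128)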
187184
######################################################################
188185
# NestedTensor and Dense tensor support
189186
# -------------------------------------
190-
#
187+
#
188+
# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
189+
# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see
190+
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
191+
#
191192

192193
import random
193194
def generate_rand_batch(
@@ -227,21 +228,31 @@ def generate_rand_batch(
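
Aside (illustrative, not part of this commit): a small sketch of the NestedTensor idea the new comment lines describe, using the prototype ``torch.nested`` API. The sequence lengths and feature size are arbitrary assumptions.

import torch

# Three sequences of different lengths, stored in one batch without padding.
seqs = [torch.randn(seq_len, 8) for seq_len in (3, 5, 2)]
nt = torch.nested.nested_tensor(seqs)

# Convert to a padded dense tensor when an op needs rectangular input.
padded = torch.nested.to_padded_tensor(nt, padding=0.0)
print(padded.shape)  # torch.Size([3, 5, 8])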
227228
seq_len_list,
228229
)
229230

230-
# Currently the fastpaths don't support NestedTensor for training
231231
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
232232
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
233-
model.requires_grad_(False)
234-
print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
235-
print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
233+
234+
# Currently the fused implementations don't support NestedTensor for training
235+
model.eval()
236+
237+
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
238+
try:
239+
print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
240+
print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
241+
except RuntimeError:
242+
print("FlashAttention is not supported. See warnings for reasons.")
236243

237244

238245
######################################################################
239-
# Composable with 2.0 Features
246+
# Using SDPA with torch.compile
240247
# ============================
241-
#
242-
# Scaled dot product attention is composable with torch.compile(). Lets
243-
# try compiling the above CausalSelfAttention module
244-
#
248+
#
249+
# With the release of PyTorch 2.0, a new feature called
250+
# ``torch.compile()`` has been introduced, which can provide
251+
# significant performance improvements over eager mode.
252+
# Scaled dot product attention is fully composable with ``torch.compile()``.
253+
# To demonstrate this, let's compile the CausalSelfAttention module using
254+
# ``torch.compile()`` and observe the resulting performance improvements.
255+
#
245256

246257
batch_size = 32
247258
max_sequence_len = 256
@@ -252,20 +263,21 @@ def generate_rand_batch(
 
 
 compiled_model = torch.compile(model)
-# Lets warm it up once
+# Let's compile it
 compiled_model(x)
 print(
     f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
 
 
 ######################################################################
-# HMM..
-# ~~~~~
-#
-# That is not what we were expecting. Let's dig a little deeper.
+#
+# The exact execution time is dependent on machine, however the results for mine:
+# The non compiled module runs in 166.616 microseconds
+# The compiled module runs in 166.726 microseconds
+# That is not what we were expecting. Let's dig a little deeper.
 # PyTorch comes with an amazing built-in profiler that you can use to
 # inspect the performance characteristics of your code.
-#
+#
 
 from torch.profiler import profile, record_function, ProfilerActivity
 activities = [ProfilerActivity.CPU]
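
Aside (illustrative, not part of this commit): the next hunk switches the profiler report to sort by ``cuda_time_total``; GPU times only appear in that table if ``ProfilerActivity.CUDA`` is recorded, roughly as sketched below. The tensor shapes and the CPU fallback sort key are assumptions.

import torch
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)  # required for cuda_time_total to be populated

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.rand(2, 8, 256, 64, device=device)

with profile(activities=activities) as prof:
    with record_function("sdpa"):
        for _ in range(25):
            F.scaled_dot_product_attention(q, k, v)

sort_key = "cuda_time_total" if torch.cuda.is_available() else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))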
@@ -276,39 +288,44 @@ def generate_rand_batch(
     with record_function(" Non-Compilied Causal Attention"):
         for _ in range(25):
             model(x)
-print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 
 with profile(activities=activities, record_shapes=False) as prof:
     with record_function("Compiled Causal Attention"):
         for _ in range(25):
             compiled_model(x)
-print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 # For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
-# prof.export_chrome_trace("compiled_causal_attention_trace.json")
+# prof.export_chrome_trace("compiled_causal_attention_trace.json").
 
 
 
 
 ######################################################################
-# The problem here is that ``torch.compile`` is very good at removing the
+# The previous code snippet generates a report of the top 10 PyTorch functions
+# that consumed the most GPU execution time, for both the compiled and non-compiled module.
+# The analysis reveals that the majority of time spent on the GPU is concentrated
+# on the same set of functions for both modules.
+# The reason for this here is that ``torch.compile`` is very good at removing the
 # framework overhead associated with PyTorch. If your model is launching
 # large, efficient CUDA kernels, which in this case CausaulSelfAttention
-# is, then the overhead of ``torch.compile`` can hurt performance.
-#
+# is, then the overhead of PyTorch can be hidden.
+#
 # In reality, your module does not normally consist of a singular
 # CausalSelfAttention block. When experimenting with Andrej Karpathy’s
 # `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
-# the module took the time per train step from: ``902.01ms`` to
-# ``552.06ms``!
-#
+# the module took the time per train step from: ``6090.49ms`` to
+# ``3273.17ms``! This was done on commit: ae3a8d5 of NanoGPT training on
+# the shakespeare dataset.
+#
 
 
 ######################################################################
 # Conclusion
 # ==========
-#
+#
 # In this tutorial, we have demonstrated the basic usage of
 # ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
 # the ``sdp_kernel`` context manager can be used to assert a certain
@@ -317,4 +334,4 @@ def generate_rand_batch(
 # compilable. In the process we have shown how to the profiling tools can
 # be used to explore the performance characteristics of a user defined
 # module.
-#
+#
