"""
- An overview of torch.nn.functional.scaled_dot_product_attention
- ===============================================================
+ Implementing High-Performance Transformers with Scaled Dot Product Attention
+ ================================================================================

"""


######################################################################
# Summary
# ~~~~~~~~
- #
- # In this tutorial we want to highlight a new ``torch.nn.functional`` function
+ #
+ # In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
- # There is some extensive documentation on the function in the `PyTorch
- # documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
- # This function has already been incorporated into torch.nn.MHA
- # (Multi-Head Attention) and ``torch.nn.TransformerEncoderLayer``.
- #
+ # For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
+ # This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
+ #
# Overview
- # ~~~~~~~
- # At a high level this PyTorch function calculates the
- # scaled dot product attention between query, key, and value according to
+ # ~~~~~~~~~
+ # At a high level, this PyTorch function calculates the
+ # scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
- # need <https://arxiv.org/abs/1706.03762>`__. While this function can be
- # written in PyTorch using existing functions, for GPU tensors this
- # function will implicitly dispatch to an optimized implementation. The
- # function is also highly modular and can be used to implement other
- # attention mechanisms such as
- # `Linformer <https://arxiv.org/abs/2006.04768>`__
- #
- # Fused implementations:
+ # need <https://arxiv.org/abs/1706.03762>`__. While this function can
+ # be written in PyTorch using existing functions, a fused implementation can provide
+ # large performance benefits over a naive implementation.
+ #
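For reference, the quantity the function computes is the scaled dot product attention defined in that paper, with d_k the dimension of the key vectors:

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V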
+ # Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
- #
- # For CUDA tensor inputs the function will dispatch into one of three
- # implementations:
+ #
+ # For CUDA tensor inputs, the function will dispatch into one of the following
+ # implementations:
+ #
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
- # Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__ \*
- # `Memory-Efficient
- # Attention <https://github.com/facebookresearch/xformers>`__ \* A PyTorch
- # implementation defined in C++
- #
+ # * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
+ # * A PyTorch implementation defined in C++
+ #

import torch
import torch.nn as nn
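As a minimal sketch of the basic call (the tensor shapes, dtype, and device below are arbitrary choices for illustration, not taken from the file):

# Minimal illustration: query, key, and value have shape
# (batch_size, num_heads, seq_len, head_dim).
device = "cuda" if torch.cuda.is_available() else "cpu"
query = torch.rand(32, 8, 128, 64, device=device)
key = torch.rand(32, 8, 128, 64, device=device)
value = torch.rand(32, 8, 128, 64, device=device)
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([32, 8, 128, 64])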
######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
- #
+ #
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
- # the function is indeed using the fasted implementation for their
- # specific inputs the context manager can be used to sweep through
+ # the function is indeed using the fastest implementation for their
+ # specific inputs, the context manager can be used to sweep through
# measuring performance.
- #
+ #

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
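The ``benchmark_torch_function_in_microseconds`` helper and the ``backend_map`` referenced in the next hunk are elided from this diff; a plausible sketch of both, assuming ``torch.utils.benchmark.Timer`` and the ``sdp_kernel``/``SDPBackend`` utilities from ``torch.backends.cuda``:

# Sketch of a timing helper built on torch.utils.benchmark; returns the mean
# runtime of f(*args, **kwargs) in microseconds.
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


# Sketch of the backend_map used with sdp_kernel below: each entry enables
# exactly one SDPA backend and disables the others.
from torch.backends.cuda import sdp_kernel, SDPBackend

backend_map = {
    SDPBackend.MATH: {
        "enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {
        "enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}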
@@ -102,35 +96,38 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
-     print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
-
+     try:
+         print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+     except RuntimeError:
+         print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-     print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+     try:
+         print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+     except RuntimeError:
+         print("EfficientAttention is not supported. See warnings for reasons.")


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
- #
+ #
# Depending on what machine you ran the above cell on and what hardware is
- # available your results might be different.
- # - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all
- # are running on CPU then the context manager will have no effect and all
- # three run should return similar timings. - Depending on what Compute
- # Capability your graphics card supports FlashAttention or memory
- # efficient might have failed.
- #
+ # available, your results might be different.
+ # - If you don’t have a GPU and are running on CPU, then the context manager
+ #   will have no effect and all three runs should return similar timings.
+ # - Depending on what compute capability your graphics card supports,
+ #   FlashAttention or memory-efficient attention might have failed.
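A quick way to see why a backend might be unavailable, sketched here with standard ``torch.cuda`` queries (not part of the file itself):

# Sketch: report whether a GPU is present and its compute capability; the fused
# kernels are only available for certain architectures and input configurations.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()} reports compute capability sm_{major}{minor}")
else:
    print("No CUDA device available; SDPA will use the C++ (CPU) implementation.")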


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
- #
+ #
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
- #
+ #

class CausalSelfAttention(nn.Module):

@@ -187,7 +184,11 @@ def forward(self, x):
######################################################################
# NestedTensor and Dense tensor support
# -------------------------------------
- #
+ #
+ # SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
+ # without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see
+ # `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
+ #
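As a small sketch of what a NestedTensor looks like, independent of the ``generate_rand_batch`` helper defined next (the sequence lengths and feature size below are made up):

# Sketch: a NestedTensor packs sequences of different lengths without padding.
sequences = [torch.randn(seq_len, 8) for seq_len in (2, 5, 3)]
nt = torch.nested.nested_tensor(sequences)
print(nt.is_nested)  # True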

import random
def generate_rand_batch(
@@ -227,21 +228,31 @@ def generate_rand_batch(
        seq_len_list,
    )

- # Currently the fastpaths don't support NestedTensor for training
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
- model.requires_grad_(False)
- print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
- print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
+
+ # Currently the fused implementations don't support NestedTensor for training
+ model.eval()
+
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+     try:
+         print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
+         print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
+     except RuntimeError:
+         print("FlashAttention is not supported. See warnings for reasons.")


######################################################################
- # Composable with 2.0 Features
+ # Using SDPA with torch.compile
# ==============================
- #
- # Scaled dot product attention is composable with torch.compile(). Lets
- # try compiling the above CausalSelfAttention module
- #
+ #
+ # With the release of PyTorch 2.0, a new feature called
+ # ``torch.compile()`` has been introduced, which can provide
+ # significant performance improvements over eager mode.
+ # Scaled dot product attention is fully composable with ``torch.compile()``.
+ # To demonstrate this, let's compile the CausalSelfAttention module using
+ # ``torch.compile()`` and observe the resulting performance improvements.
+ #

batch_size = 32
max_sequence_len = 256
@@ -252,20 +263,21 @@ def generate_rand_batch(


compiled_model = torch.compile(model)
- # Lets warm it up once
+ # Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
- # HMM..
- # ~~~~~
- #
- # That is not what we were expecting. Let's dig a little deeper.
+ #
+ # The exact execution time is dependent on the machine; however, the results for mine were:
+ # the non-compiled module runs in 166.616 microseconds and
+ # the compiled module runs in 166.726 microseconds.
+ # That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
- #
+ #

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
@@ -276,39 +288,44 @@ def generate_rand_batch(
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
- print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
- print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
- # prof.export_chrome_trace("compiled_causal_attention_trace.json")
+ # prof.export_chrome_trace("compiled_causal_attention_trace.json").




######################################################################
- # The problem here is that ``torch.compile`` is very good at removing the
+ # The previous code snippet generates a report of the top 10 PyTorch functions
+ # that consumed the most GPU execution time, for both the compiled and non-compiled module.
+ # The analysis reveals that the majority of time spent on the GPU is concentrated
+ # on the same set of functions for both modules.
+ # The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case CausalSelfAttention
- # is, then the overhead of ``torch.compile`` can hurt performance.
- #
+ # is, then the overhead of PyTorch can be hidden.
+ #
# In reality, your module does not normally consist of a singular
# CausalSelfAttention block. When experimenting with Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
- # the module took the time per train step from: ``902.01ms`` to
- # ``552.06ms``!
- #
+ # the module took the time per train step from ``6090.49ms`` to
+ # ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT, training on
+ # the Shakespeare dataset.
+ #


######################################################################
# Conclusion
# ==========
- #
+ #
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert a certain
@@ -317,4 +334,4 @@ def generate_rand_batch(
# compilable. In the process, we have shown how the profiling tools can
# be used to explore the performance characteristics of a user-defined
# module.
- #
+ #