"""
Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
================================================================================

"""


######################################################################
# Summary
# ~~~~~~~~
#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation (a naive reference
# version is sketched below, after the example usage).
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
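

######################################################################
# For reference, here is a naive SDPA written with existing PyTorch ops,
# following the ``softmax(QK^T / sqrt(d_k))V`` definition from the paper.
# This is only an illustrative sketch for comparison purposes: it is not the
# code path that ``F.scaled_dot_product_attention`` executes, and the helper
# name ``naive_scaled_dot_product_attention`` is ours, not part of the
# PyTorch API.

import math

def naive_scaled_dot_product_attention(query, key, value):
    # Scores: (..., L, E) @ (..., E, S) -> (..., L, S), scaled by 1/sqrt(E)
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    # Weighted sum of values: (..., L, S) @ (..., S, E) -> (..., L, E)
    return attn_weight @ value

# The naive and fused results should agree up to floating point error
print(torch.allclose(naive_scaled_dot_product_attention(query, key, value),
                     F.scaled_dot_product_attention(query, key, value),
                     atol=1e-5))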


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through the
# implementations and measure their performance.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don’t have a GPU and are running on CPU, then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.

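######################################################################
# As a quick way to see why a fused kernel may have been skipped, you can
# print the GPU's compute capability and which SDPA backends are currently
# enabled. This is a minimal diagnostic sketch; the dispatcher also checks
# other input properties (for example dtype and head dimension) that are not
# shown here.

if device == "cuda":
    # FlashAttention, for instance, requires half-precision inputs and a
    # sufficiently recent compute capability.
    print(f"Compute capability: {torch.cuda.get_device_capability()}")
print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory-efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled: {torch.backends.cuda.math_sdp_enabled()}")
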

######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)


######################################################################
# NestedTensor and Dense tensor support
# -------------------------------------
#
# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
#

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support NestedTensor for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")


######################################################################
# Using SDPA with torch.compile
# ==============================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non-compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's warm up the compiled module; the first call triggers compilation
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
#
# The exact execution time depends on your machine; however, the results for mine were:
#
# * The non-compiled module runs in 166.616 microseconds
# * The compiled module runs in 166.726 microseconds
#
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ae3a8d5 of NanoGPT, training on
# the Shakespeare dataset.
#


######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. We also built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# torch compilable. In the process, we have shown how the profiling tools
# can be used to explore the performance characteristics of a user-defined
# module.
#