"""
An overview of torch.nn.functional.scaled_dot_product_attention
===============================================================

"""


######################################################################
# Summary
# ~~~~~~~
#
# In this tutorial we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# There is extensive documentation on the function in the `PyTorch
# documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention``
# and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~
#
# At a high level, this PyTorch function calculates the
# scaled dot product attention between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can be
# written in PyTorch using existing functions, for GPU tensors this
# function will implicitly dispatch to an optimized implementation. The
# function is also highly modular and can be used to implement other
# attention mechanisms such as
# `Linformer <https://arxiv.org/abs/2006.04768>`__.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs the function will dispatch into one of three
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
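
# For intuition, here is a rough reference sketch of what the function
# computes: softmax of the scaled query-key scores applied to the values
# (ignoring masking and dropout). This is only an illustrative equivalent of
# the math definition, not the fused kernels the function dispatches to, and
# the helper name below is purely hypothetical.
import math

def sdpa_reference_sketch(query, key, value):
    # Scale scores by 1/sqrt(head dimension), softmax over the key dimension,
    # then aggregate the values.
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

# The two results should agree up to floating point tolerance.
print(torch.allclose(F.scaled_dot_product_attention(query, key, value),
                     sdpa_reference_sketch(query, key, value), atol=1e-5))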


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through and
# measure performance.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don't have a GPU and are running on CPU, then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory efficient attention might have failed.
#

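
# As an optional, illustrative pattern (not part of the original example):
# you can probe whether a particular fused backend supports your inputs on
# the current hardware by enabling only that backend and catching the
# RuntimeError raised when no kernel is available.
try:
    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        F.scaled_dot_product_attention(query, key, value)
    print("FlashAttention ran for these inputs on this device")
except RuntimeError as error:
    print("FlashAttention is not supported here:", error)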

######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
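
# Quick, illustrative sanity check (not part of the original example): run a
# forward pass on a small random batch; the batch size and sequence length
# here are arbitrary choices.
sample_input = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
print(model(sample_input).shape)  # (batch_size, sequence_length, embed_dimension)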


######################################################################
# NestedTensor and Dense tensor support
# -------------------------------------
#

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

# Currently the fastpaths don't support NestedTensor for training
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
model.requires_grad_(False)
print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
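
# Illustrative aside: a nested tensor can be materialized as a padded dense
# tensor (padding value 0.0 here), which makes the ragged structure explicit.
padded = torch.nested.to_padded_tensor(random_nt, 0.0)
print(padded.shape)  # (batch_size, longest sequence in the batch, embed_dimension)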


######################################################################
# Composable with 2.0 Features
# ============================
#
# Scaled dot product attention is composable with ``torch.compile()``. Let's
# try compiling the above ``CausalSelfAttention`` module.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non-compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's warm it up once
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
# HMM..
# ~~~~~
#
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
# prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The problem here is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of ``torch.compile`` can hurt performance.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``902.01ms`` to
# ``552.06ms``!
#


######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. We also built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# torch compilable. In the process we have shown how the profiling tools
# can be used to explore the performance characteristics of a user-defined
# module.
#