 1 |  1 | """
 2 |  2 | Create High-Performance Transformer Variations with Scaled Dot Product Attention
 3 |    | -===============================================================
   |  3 | +================================================================================
 4 |  4 |
 5 |  5 | """
 6 |  6 |
 9 |  9 | # Summary
10 | 10 | # ~~~~~~~~
11 | 11 | #
12 |    | -# In this tutorial we want to highlight a new ``torch.nn.functional`` function
   | 12 | +# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
13 | 13 | # that can be helpful for implementing transformer architectures. The
14 | 14 | # function is named ``torch.nn.functional.scaled_dot_product_attention``.
15 |    | -# There is some extensive documentation on the function in the `PyTorch
16 |    | -# documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
   | 15 | +# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
17 | 16 | # This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
18 | 17 | #
19 | 18 | # Overview
20 |    | -# ~~~~~~~
21 |    | -# At a high level this PyTorch function calculates the
22 |    | -# scaled dot product attention between query, key, and value according to
   | 19 | +# ~~~~~~~~~
   | 20 | +# At a high level, this PyTorch function calculates the
   | 21 | +# scaled dot product attention (SDPA) between query, key, and value according to
23 | 22 | # the definition found in the paper `Attention is all you
24 | 23 | # need <https://arxiv.org/abs/1706.03762>`__. While this function can be
25 | 24 | # written in PyTorch using existing functions, for GPU tensors this
28 | 27 | # attention mechanisms such as
29 | 28 | # `Linformer <https://arxiv.org/abs/2006.04768>`__
30 | 29 | #
31 |    | -# Fused implementations:
   | 30 | +# Fused implementations
32 | 31 | # ~~~~~~~~~~~~~~~~~~~~~~
33 | 32 | #
34 |    | -# For CUDA tensor inputs the function will dispatch into one of three
35 |    | -# implementations
   | 33 | +# For CUDA tensor inputs, the function will dispatch into one of the following
   | 34 | +# implementations:
36 | 35 | #
37 | 36 | # * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
38 | 37 | # * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
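Editor's note: to make the API introduced above concrete, here is a minimal sketch of calling ``torch.nn.functional.scaled_dot_product_attention`` directly. The tensor shapes and the ``is_causal`` flag are arbitrary choices for illustration, not values taken from the tutorial.

# A minimal call to scaled_dot_product_attention; shapes are illustrative.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# (batch, num_heads, seq_len, head_dim)
query = torch.rand(2, 8, 64, 32, device=device)
key = torch.rand(2, 8, 64, 32, device=device)
value = torch.rand(2, 8, 64, 32, device=device)

# Computes softmax(QK^T / sqrt(head_dim)) V, dispatching to a fused kernel
# when the inputs and hardware allow it.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 64, 32])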
@@ -117,10 +116,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
117 | 116 | # ~~~~~~~~~~~~~~~~~~~
118 | 117 | #
119 | 118 | # Depending on what machine you ran the above cell on and what hardware is
120 |     | -# available your results might be different.
121 |     | -# - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all
    | 119 | +# available, your results might be different.
    | 120 | +# - If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all
123 |     | -#   three run should return similar timings.
    | 122 | +#   three runs should return similar timings.
124 | 123 | #
125 | 124 |
126 | 125 |
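Editor's note: as a rough sketch of the backend selection this passage refers to, the snippet below restricts SDPA to a single fused kernel. It assumes the PyTorch 2.0-era ``torch.backends.cuda.sdp_kernel`` context manager (newer releases expose ``torch.nn.attention.sdpa_kernel`` instead); shapes and dtype are illustrative.

# Restricting SDPA to one backend with the PyTorch 2.0-era context manager.
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    q = torch.rand(2, 8, 64, 32, device="cuda", dtype=torch.float16)
    k = torch.rand(2, 8, 64, 32, device="cuda", dtype=torch.float16)
    v = torch.rand(2, 8, 64, 32, device="cuda", dtype=torch.float16)
    # Allow only FlashAttention; SDPA raises if it cannot satisfy the request
    # on this GPU / input combination.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)
else:
    # On CPU the context manager has no effect and the math implementation runs.
    q = k = v = torch.rand(2, 8, 64, 32)
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)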
@@ -189,7 +188,7 @@ def forward(self, x):
189 | 188 | # NestedTensor and Dense tensor support
190 | 189 | # -------------------------------------
191 | 190 | #
192 |     | -# Scaled Dot Product Attention supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
    | 191 | +# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
193 | 192 | # without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see
194 | 193 | # `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
195 | 194 | #
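Editor's note: a small sketch of the NestedTensor path described above, assuming ``torch.nested.nested_tensor`` to build a batch of three sequences with different lengths and passing it to SDPA without padding. The lengths and head sizes are arbitrary, and the guard reflects that nested-input support varies by device, dtype, and PyTorch release.

# Building a nested (variable-length) batch and passing it straight to SDPA.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 8
seq_lens = [6, 2, 9]  # three sequences of different lengths, no padding

def rand_nested():
    # One (num_heads, seq_len, head_dim) tensor per sequence.
    return torch.nested.nested_tensor(
        [torch.rand(num_heads, L, head_dim) for L in seq_lens]
    )

query, key, value = rand_nested(), rand_nested(), rand_nested()
print(query.is_nested)  # True

# Nested-input support depends on the device, dtype, and PyTorch release,
# hence the guard; when it succeeds the output is nested as well.
try:
    out = F.scaled_dot_product_attention(query, key, value)
    print(out.is_nested)  # per-sequence lengths are preserved
except RuntimeError as err:
    print(f"Nested inputs not supported by this build/backend: {err}")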
@@ -244,8 +243,8 @@ def generate_rand_batch(
244 | 243 | # Using SDPA with torch.compile
245 | 244 | # =============================
246 | 245 | #
247 |     | -# Scaled dot product attention is composable with torch.compile(). Lets
248 |     | -# try compiling the above CausalSelfAttention module
    | 246 | +# Scaled dot product attention is composable with ``torch.compile()``. Let's
    | 247 | +# try compiling the above ``CausalSelfAttention`` module:
249 | 248 | #
250 | 249 |
251 | 250 | batch_size = 32
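Editor's note: the tutorial compiles its ``CausalSelfAttention`` module, which is not reproduced in this excerpt. The sketch below uses a tiny stand-in attention block so it stays self-contained, and simply shows that a module calling SDPA composes with ``torch.compile``; the name ``TinyAttention`` and all sizes are illustrative.

# A tiny stand-in attention module (not the tutorial's CausalSelfAttention)
# compiled with torch.compile; all sizes below are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, num_heads, T, head_dim) for SDPA
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyAttention().to(device)
compiled_model = torch.compile(model)       # SDPA traces through torch.compile
x = torch.rand(32, 128, 64, device=device)  # (batch_size, seq_len, embed_dim)
print(compiled_model(x).shape)              # torch.Size([32, 128, 64])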
@@ -289,7 +288,7 @@ def generate_rand_batch(
289 | 288 | print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
290 | 289 |
291 | 290 | # For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
292 |     | -# prof.export_chrome_trace("compiled_causal_attention_trace.json")
    | 291 | +# ``prof.export_chrome_trace("compiled_causal_attention_trace.json")``.
293 | 292 |
294 | 293 |
295 | 294 |
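Editor's note: for reference, a self-contained sketch of producing a profile like the one printed above and exporting a Chrome trace with ``torch.profiler``. Profiling a raw SDPA call rather than the tutorial's compiled model, the iteration count, and the output file name ``sdpa_profile_trace.json`` are all placeholder choices.

# Profiling repeated SDPA calls and exporting a Chrome trace.
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.rand(32, 8, 128, 64, device=device)  # (batch, heads, seq, head_dim)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):  # a few iterations for more stable aggregate timings
        F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
# Load the exported JSON in chrome://tracing (or Perfetto) for a timeline view.
prof.export_chrome_trace("sdpa_profile_trace.json")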