1 | 1 | """
2 |   |   -Implement High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
  | 2 |   +Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
3 | 3 | ================================================================================
4 | 4 |
5 | 5 | """

20 | 20 | # At a high level, this PyTorch function calculates the
21 | 21 | # scaled dot product attention (SDPA) between query, key, and value according to
22 | 22 | # the definition found in the paper `Attention is all you
23 |    |   -# need <https://arxiv.org/abs/1706.03762>`__. While this function can be
24 |    |   -# written in PyTorch using existing functions, for GPU tensors this
25 |    |   -# function will implicitly dispatch to an optimized implementation.
   | 23 |   +# need <https://arxiv.org/abs/1706.03762>`__. While this function can
   | 24 |   +# be written in PyTorch using existing functions, a fused implementation can provide
   | 25 |   +# large performance benefits over a naive implementation.
26 | 26 | #
27 | 27 | # Fused implementations
28 | 28 | # ~~~~~~~~~~~~~~~~~~~~~~
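The reworded lines emphasize that ``torch.nn.functional.scaled_dot_product_attention`` dispatches to a fused kernel when one is available. A minimal sketch (not from the patch itself; the tensor shapes are illustrative assumptions) of calling the function directly:

```python
# A minimal sketch of calling SDPA directly; shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)

# Computes softmax(Q @ K^T / sqrt(head_dim)) @ V, dispatching to a fused
# kernel (FlashAttention, memory-efficient attention) or the math fallback
# depending on the inputs and the hardware.
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```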
@@ -114,10 +114,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
114 | 114 | #
115 | 115 | # Depending on what machine you ran the above cell on and what hardware is
116 | 116 | # available, your results might be different.
117 |     |   -# - If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all
118 |     |   -# are running on CPU then the context manager will have no effect and all
119 |     |   -# three runs should return similar timings.
120 |     |   -#
    | 117 |   +# - If you don’t have a GPU and are running on CPU then the context manager
    | 118 |   +# will have no effect and all three runs should return similar timings.
    | 119 |   +# - Depending on what compute capability your graphics card supports
    | 120 |   +# flash attention or memory efficient might have failed.
121 | 121 |
122 | 122 |
123 | 123 | ######################################################################
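The context manager mentioned here is the kernel-selection guard used earlier in the tutorial. A minimal sketch of that pattern, assuming a CUDA build of PyTorch 2.0+ with ``torch.backends.cuda.sdp_kernel`` (newer releases expose ``torch.nn.attention.sdpa_kernel`` instead):

```python
# A minimal sketch: restrict SDPA to the FlashAttention backend and fall back
# gracefully if the hardware does not support it. Assumes PyTorch 2.0+.
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    try:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            out = F.scaled_dot_product_attention(q, k, v)
    except RuntimeError:
        # GPUs with an older compute capability cannot run FlashAttention.
        print("FlashAttention is not supported on this GPU.")
else:
    # On CPU the context manager has no effect; the math path is used.
    print("No GPU available; all three backend choices run the same CPU path.")
```

Without the context manager, the dispatcher simply picks the fastest backend available, which is why the three timings differ only on supported GPUs.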
@@ -186,7 +186,7 @@ def forward(self, x):
186 | 186 | # -------------------------------------
187 | 187 | #
188 | 188 | # SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
189 |     |   -# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensor's see
    | 189 |   +# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see
190 | 190 | # `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
191 | 191 | #
192 | 192 |
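A minimal sketch of the NestedTensor path described above (not from the patch; the sequence lengths and sizes are illustrative assumptions):

```python
# A minimal sketch: self-attention over a ragged batch via NestedTensor,
# with no padding to the maximum sequence length.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
num_heads, head_dim = 8, 64

# Two sequences of different lengths; each entry is (num_heads, seq_len, head_dim),
# so the NestedTensor has layout (batch, heads, ragged seq_len, head_dim).
seqs = [torch.randn(num_heads, 100, head_dim, device=device, dtype=dtype),
        torch.randn(num_heads, 37, head_dim, device=device, dtype=dtype)]
nt = torch.nested.nested_tensor(seqs)

try:
    out = F.scaled_dot_product_attention(nt, nt, nt)
    print(out.is_nested)  # True
except RuntimeError:
    # NestedTensor inputs may require one of the fused CUDA backends.
    print("NestedTensor SDPA is not supported on this platform.")
```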
@@ -246,8 +246,12 @@ def generate_rand_batch(
246 | 246 | # Using SDPA with torch.compile
247 | 247 | # ============================
248 | 248 | #
249 |     |   -# Scaled dot product attention is composable with ``torch.compile()``. Let's
250 |     |   -# try compiling the above CausalSelfAttention module:
    | 249 |   +# With the release of PyTorch 2.0, a new feature called
    | 250 |   +# ``torch.compile()`` has been introduced, which can provide
    | 251 |   +# significant performance improvements over eager mode.
    | 252 |   +# Scaled dot product attention is fully composable with ``torch.compile()``.
    | 253 |   +# To demonstrate this, let's compile the CausalSelfAttention module using
    | 254 |   +# ``torch.compile()`` and observe the resulting performance improvements.
251 | 255 | #
252 | 256 |
253 | 257 | batch_size = 32
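A minimal sketch of the compile-and-run pattern this hunk introduces; the ``TinyAttention`` module below is a hypothetical stand-in for the tutorial's ``CausalSelfAttention``:

```python
# A minimal sketch: wrap an SDPA-based module with torch.compile() (PyTorch 2.0+).
import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    """Hypothetical stand-in for the tutorial's CausalSelfAttention module."""

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim) for SDPA.
        q, k, v = (t.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
                   for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyAttention().to(device)
compiled_model = torch.compile(model)  # compilation happens on the first call
x = torch.randn(32, 128, 512, device=device)
out = compiled_model(x)
```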
@@ -304,7 +308,7 @@ def generate_rand_batch(
304 | 308 | # that consumed the most GPU execution time, for both the compiled and non-compiled module.
305 | 309 | # The analysis reveals that the majority of time spent on the GPU is concentrated
306 | 310 | # on the same set of functions for both modules.
307 |     |   -# The problem here is that ``torch.compile`` is very good at removing the
    | 311 |   +# The reason for this is that ``torch.compile`` is very good at removing the
308 | 312 | # framework overhead associated with PyTorch. If your model is launching
309 | 313 | # large, efficient CUDA kernels, which in this case CausalSelfAttention
310 | 314 | # is, then the overhead of PyTorch can be hidden.
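The kernel-level comparison the text refers to can be reproduced with ``torch.profiler``. A minimal sketch, assuming a CUDA device and reusing the hypothetical ``model`` / ``compiled_model`` pair from the previous snippet:

```python
# A minimal sketch: list the ops that dominate GPU time for the eager and
# compiled modules. Assumes CUDA and the `model`/`compiled_model` names above.
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(32, 128, 512, device="cuda")

for name, module in [("eager", model), ("compiled", compiled_model)]:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(10):
            module(x)
    print(f"== {name} ==")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```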