Commit 32a3bba

more tweaks
1 parent 8e581e2 commit 32a3bba

2 files changed: +17 -13 lines changed

index.rst

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ What's new in PyTorch tutorials?
    :tags: Model-Optimization

 .. customcarditem::
-   :header: (beta) Implement High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
+   :header: (beta) Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
    :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components.
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
    :link: intermediate/scaled_dot_product_attention_tutorial.html

intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 16 additions & 12 deletions
@@ -1,5 +1,5 @@
 """
-Implement High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
+Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
 ================================================================================

 """
@@ -20,9 +20,9 @@
 # At a high level, this PyTorch function calculates the
 # scaled dot product attention (SDPA) between query, key, and value according to
 # the definition found in the paper `Attention is all you
-# need <https://arxiv.org/abs/1706.03762>`__. While this function can be
-# written in PyTorch using existing functions, for GPU tensors this
-# function will implicitly dispatch to an optimized implementation.
+# need <https://arxiv.org/abs/1706.03762>`__. While this function can
+# be written in PyTorch using existing functions, a fused implementation can provide
+# large performance benefits over a naive implementation.
 #
 # Fused implementations
 # ~~~~~~~~~~~~~~~~~~~~~~
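For context on this hunk, the fused entry point being described is ``torch.nn.functional.scaled_dot_product_attention``, which computes the same quantity as a naive composition of existing ops. A minimal sketch of the comparison (shapes are arbitrary, chosen only for illustration):

import math
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)

# Naive composition from existing ops: softmax(Q K^T / sqrt(d)) V
scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
naive_out = torch.softmax(scores, dim=-1) @ value

# Fused entry point discussed in the tutorial
fused_out = F.scaled_dot_product_attention(query, key, value)

print(torch.allclose(naive_out, fused_out, atol=1e-5))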
@@ -114,10 +114,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 #
 # Depending on what machine you ran the above cell on and what hardware is
 # available, your results might be different.
-# - If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all
-# are running on CPU then the context manager will have no effect and all
-# three runs should return similar timings.
-#
+# - If you don’t have a GPU and are running on CPU then the context manager
+#   will have no effect and all three runs should return similar timings.
+# - Depending on what compute capability your graphics card supports,
+#   flash attention or memory efficient might have failed.


 ######################################################################
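The context manager this hunk refers to is the SDPA backend selector the tutorial benchmarks with. A minimal sketch of restricting dispatch to a single backend, assuming the PyTorch 2.0-era ``torch.backends.cuda.sdp_kernel`` API (keyword names may differ in later releases):

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
v = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)

# Restrict SDPA to the flash-attention kernel only. On CPU the context
# manager has no effect; on GPUs whose compute capability does not support
# flash attention the call below raises an error instead of silently falling back.
try:
    with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    print("Flash attention backend unavailable:", err)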
@@ -186,7 +186,7 @@ def forward(self, x):
 # -------------------------------------
 #
 # SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
-# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensor's see
+# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see
 # `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
 #
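As a companion to the NestedTensor paragraph in this hunk, a minimal sketch of running SDPA over a batch of variable-length sequences without padding; shapes are illustrative, and a CUDA device with a fused kernel is assumed, since the fallback math path may reject nested inputs:

import torch
import torch.nn.functional as F

device = "cuda"  # assumption: a CUDA device with a fused SDPA kernel
num_heads, head_dim = 8, 64
seq_lens = [10, 24, 17]  # three sequences of different lengths, no padding

def rand_seq(seq_len):
    return torch.randn(num_heads, seq_len, head_dim, device=device, dtype=torch.float16)

# Each nested tensor is logically (batch, num_heads, ragged seq_len, head_dim).
query = torch.nested.nested_tensor([rand_seq(s) for s in seq_lens])
key = torch.nested.nested_tensor([rand_seq(s) for s in seq_lens])
value = torch.nested.nested_tensor([rand_seq(s) for s in seq_lens])

out = F.scaled_dot_product_attention(query, key, value)
print(out.is_nested)  # True: one output sequence per input length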

@@ -246,8 +246,12 @@ def generate_rand_batch(
 # Using SDPA with torch.compile
 # ============================
 #
-# Scaled dot product attention is composable with ``torch.compile()``. Let's
-# try compiling the above CausalSelfAttention module:
+# With the release of PyTorch 2.0, a new feature called
+# ``torch.compile()`` has been introduced, which can provide
+# significant performance improvements over eager mode.
+# Scaled dot product attention is fully composable with ``torch.compile()``.
+# To demonstrate this, let's compile the CausalSelfAttention module using
+# ``torch.compile()`` and observe the resulting performance improvements.
 #

 batch_size = 32
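A minimal sketch of the compile step this hunk describes; ``TinyAttention`` below is a hypothetical stand-in for the tutorial's CausalSelfAttention module, not the tutorial's actual code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    """Hypothetical stand-in for the tutorial's CausalSelfAttention module."""
    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, heads, seq, head_dim) for SDPA
        q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.proj(y)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyAttention().to(device)
compiled_model = torch.compile(model)   # requires PyTorch 2.0+

x = torch.randn(32, 256, 512, device=device)
out = compiled_model(x)                 # first call triggers compilation
print(out.shape)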
@@ -304,7 +308,7 @@ def generate_rand_batch(
 # that consumed the most GPU execution time, for both the compiled and non-compiled module.
 # The analysis reveals that the majority of time spent on the GPU is concentrated
 # on the same set of functions for both modules.
-# The problem here is that ``torch.compile`` is very good at removing the
+# The reason for this is that ``torch.compile`` is very good at removing the
 # framework overhead associated with PyTorch. If your model is launching
 # large, efficient CUDA kernels, which in this case CausalSelfAttention
 # is, then the overhead of PyTorch can be hidden.
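The per-kernel analysis this hunk refers to can be approximated with ``torch.profiler``; a rough sketch, reusing the hypothetical ``model`` and ``compiled_model`` from the previous sketch and assuming a CUDA device:

import torch
from torch.profiler import profile, ProfilerActivity

# `model` and `compiled_model` come from the previous sketch.
x = torch.randn(32, 256, 512, device="cuda")
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

# Profile the eager module and report the kernels with the most GPU time.
with profile(activities=activities) as eager_prof:
    for _ in range(10):
        model(x)
print(eager_prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))

# Profile the compiled module the same way for comparison.
with profile(activities=activities) as compiled_prof:
    for _ in range(10):
        compiled_model(x)
print(compiled_prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))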
