
Commit df5ec8e ("more things")
1 parent 3cdd0ec commit df5ec8e

File tree: 1 file changed (+16, -17 lines)


beginner_source/scaled_dot_product_attention_tutorial.py

Lines changed: 16 additions & 17 deletions
@@ -1,6 +1,6 @@
 """
 Create High-Performance Transformer Variations with Scaled Dot Product Attention
-===============================================================
+================================================================================
 
 """
 
@@ -9,17 +9,16 @@
 # Summary
 # ~~~~~~~~
 #
-# In this tutorial we want to highlight a new ``torch.nn.functional`` function
+# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
 # that can be helpful for implementing transformer architectures. The
 # function is named ``torch.nn.functional.scaled_dot_product_attention``.
-# There is some extensive documentation on the function in the `PyTorch
-# documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
+# For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
 # This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
 #
 # Overview
-# ~~~~~~~
-# At a high level this PyTorch function calculates the
-# scaled dot product attention between query, key, and value according to
+# ~~~~~~~~~
+# At a high level, this PyTorch function calculates the
+# scaled dot product attention (SDPA) between query, key, and value according to
 # the definition found in the paper `Attention is all you
 # need <https://arxiv.org/abs/1706.03762>`__. While this function can be
 # written in PyTorch using existing functions, for GPU tensors this
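The overview edited in this hunk is where the commit spells out what ``torch.nn.functional.scaled_dot_product_attention`` computes. As an illustration only (not code from this commit), a minimal sketch comparing the fused call with the same computation written from existing ops, using arbitrary example shapes, could look like this:

import math
import torch
import torch.nn.functional as F

# Arbitrary example shapes: (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Fused/dispatched implementation (requires PyTorch >= 2.0).
out_fused = F.scaled_dot_product_attention(query, key, value)

# The same computation with existing ops: softmax(Q K^T / sqrt(E)) V,
# as defined in "Attention is all you need".
scores = query @ key.transpose(-2, -1) / math.sqrt(head_dim)
out_manual = torch.softmax(scores, dim=-1) @ value

print(torch.allclose(out_fused, out_manual, atol=1e-5))  # expected: True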
@@ -28,11 +27,11 @@
 # attention mechanisms such as
 # `Linformer <https://arxiv.org/abs/2006.04768>`__
 #
-# Fused implementations:
+# Fused implementations
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
-# For CUDA tensor inputs the function will dispatch into one of three
-# implementations
+# For CUDA tensor inputs, the function will dispatch into one of the following
+# implementations:
 #
 # * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
 # * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
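The dispatch described in this hunk can also be steered explicitly. The sketch below is illustrative and assumes the PyTorch 2.0-era ``torch.backends.cuda.sdp_kernel`` context manager (later releases expose ``torch.nn.attention.sdpa_kernel`` instead); it restricts SDPA to the FlashAttention backend:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)

# Allow only the FlashAttention backend; on CPU the context manager has no
# effect, as the benchmarking notes later in this file point out.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        out = F.scaled_dot_product_attention(q, q, q)
    except RuntimeError:
        # Raised when none of the enabled backends supports these inputs/hardware.
        out = None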
@@ -117,10 +116,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 # ~~~~~~~~~~~~~~~~~~~
 #
 # Depending on what machine you ran the above cell on and what hardware is
-# available your results might be different.
-# - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all
+# available, your results might be different.
+# - If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all
 # are running on CPU then the context manager will have no effect and all
-# three run should return similar timings.
+# three runs should return similar timings.
 #
 
 
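The timings discussed in this hunk come from the ``benchmark_torch_function_in_microseconds`` helper named in the hunk header; its body lies outside the changed lines, so the following is only a plausible sketch of such a helper built on ``torch.utils.benchmark``, not the file's actual definition:

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    # Timer handles warmup and CUDA synchronization; .mean is reported in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return t0.blocked_autorange().mean * 1e6  # seconds -> microseconds

q = torch.randn(2, 8, 128, 64)
print(f"SDPA: {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, q, q):.1f} us")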
@@ -189,7 +188,7 @@ def forward(self, x):
 # NestedTensor and Dense tensor support
 # -------------------------------------
 #
-# Scaled Dot Product Attention supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
+# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
 # without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensor's see
 # `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
 #
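To make the NestedTensor path concrete, here is a hypothetical sketch (not taken from the tutorial) of feeding a batch of variable-length sequences to SDPA without padding; whether nested inputs are accepted directly depends on the device and active backend, so the sketch falls back to a padded dense batch:

import torch
import torch.nn.functional as F

num_heads, head_dim = 8, 64
seq_lens = [10, 24, 17]  # ragged sequence lengths within one batch

# One (num_heads, seq_len_i, head_dim) tensor per sequence, no padding.
parts = [torch.randn(num_heads, L, head_dim) for L in seq_lens]
query = torch.nested.nested_tensor(parts)
key = torch.nested.nested_tensor(parts)
value = torch.nested.nested_tensor(parts)

try:
    out = F.scaled_dot_product_attention(query, key, value)
except RuntimeError:
    # Some device/backend combinations reject nested inputs; pad to a dense
    # (batch, num_heads, max_seq_len, head_dim) tensor instead.
    dense = torch.nested.to_padded_tensor(query, 0.0)
    out = F.scaled_dot_product_attention(dense, dense, dense)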
@@ -244,8 +243,8 @@ def generate_rand_batch(
 # Using SDPA with torch.compile
 # ============================
 #
-# Scaled dot product attention is composable with torch.compile(). Lets
-# try compiling the above CausalSelfAttention module
+# Scaled dot product attention is composable with ``torch.compile()``. Let's
+# try compiling the above CausalSelfAttention module:
 #
 
 batch_size = 32
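The ``CausalSelfAttention`` module referenced here is defined outside the changed lines, so the sketch below uses a minimal stand-in (a hypothetical ``TinySelfAttention``, not the tutorial's class) purely to illustrate that SDPA composes with ``torch.compile()``:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq, head_dim) as SDPA expects.
        q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

model = TinySelfAttention(embed_dim=512, num_heads=8)
compiled_model = torch.compile(model)  # wraps forward with TorchDynamo/Inductor
x = torch.randn(32, 128, 512)
out = compiled_model(x)                # first call triggers compilation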
@@ -289,7 +288,7 @@ def generate_rand_batch(
 print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
 
 # For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
-# prof.export_chrome_trace("compiled_causal_attention_trace.json")
+# prof.export_chrome_trace("compiled_causal_attention_trace.json").
 
 
 
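For the trace-export comment above, an illustrative sketch of producing such a trace with ``torch.profiler`` and viewing it in ``chrome://tracing`` (shapes and iteration count are arbitrary assumptions):

import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

q = torch.randn(8, 8, 256, 64)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    q = q.cuda()

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):
        F.scaled_dot_product_attention(q, q, q)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("compiled_causal_attention_trace.json")  # open in chrome://tracing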