Commit 3cdd0ec

more

1 parent 2b7cbb7 commit 3cdd0ec

beginner_source/scaled_dot_product_attention_tutorial.py

Lines changed: 7 additions & 2 deletions
@@ -14,7 +14,7 @@
 # function is named ``torch.nn.functional.scaled_dot_product_attention``.
 # There is some extensive documentation on the function in the `PyTorch
 # documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
-# This function has already been incorporated into torch.nn.MultiheadAttention (Multi-Head Attention) and ``torch.nn.TransformerEncoderLayer``.
+# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
 #
 # Overview
 # ~~~~~~~
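For context on the function this hunk documents: ``torch.nn.functional.scaled_dot_product_attention`` can also be called directly. A minimal sketch, assuming PyTorch 2.0 or later; the tensor shapes and sizes below are illustrative, not from the tutorial:

import torch
import torch.nn.functional as F

# Shapes follow the (batch, num_heads, seq_len, head_dim) convention.
query = torch.rand(2, 8, 16, 64)
key = torch.rand(2, 8, 16, 64)
value = torch.rand(2, 8, 16, 64)

# Computes softmax(Q @ K^T / sqrt(head_dim)) @ V, using a fused kernel
# where available and the C++ math implementation otherwise.
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 16, 64])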
@@ -32,7 +32,8 @@
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # For CUDA tensor inputs the function will dispatch into one of three
-# implementations:
+# implementations
+#
 # * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
 # * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
 # * A PyTorch implementation defined in C++
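The choice among these backends can also be constrained explicitly. A sketch using the ``torch.backends.cuda.sdp_kernel`` context manager from the PyTorch 2.0-era API (later releases moved this functionality to ``torch.nn.attention.sdpa_kernel``); shapes and dtypes here are illustrative assumptions:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
query = torch.rand(2, 8, 128, 64, device=device, dtype=dtype)
key = torch.rand(2, 8, 128, 64, device=device, dtype=dtype)
value = torch.rand(2, 8, 128, 64, device=device, dtype=dtype)

if device == "cuda":
    # Disable the math fallback so only the fused FlashAttention and
    # memory-efficient kernels are eligible; an error is raised if
    # neither backend supports the inputs.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
        out = F.scaled_dot_product_attention(query, key, value)
else:
    # On CPU, the C++ math implementation is used.
    out = F.scaled_dot_product_attention(query, key, value)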
@@ -188,6 +189,10 @@ def forward(self, x):
 # NestedTensor and Dense tensor support
 # -------------------------------------
 #
+# Scaled Dot Product Attention supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences
+# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see
+# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
+#

 import random
 def generate_rand_batch(
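To make the added paragraph concrete, a sketch of building a variable-length batch as a NestedTensor; the shapes are illustrative, and the nested-tensor path of ``scaled_dot_product_attention`` may require a CUDA device depending on the PyTorch version:

import torch

# Two sequences of different lengths share one batch without padding.
seq_a = torch.rand(10, 64)   # (seq_len=10, embed_dim)
seq_b = torch.rand(7, 64)    # (seq_len=7, embed_dim)
batch = torch.nested.nested_tensor([seq_a, seq_b])
print(batch.is_nested)  # True

# When a dense tensor is required, pad explicitly to the max length:
padded = torch.nested.to_padded_tensor(batch, padding=0.0)
print(padded.shape)  # torch.Size([2, 10, 64])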
