Commit 6c67d6d

Commit message: updates
1 parent: ee07c3c

File tree: 2 files changed, +19 -14 lines


index.rst

Lines changed: 2 additions & 2 deletions
@@ -528,7 +528,7 @@ What's new in PyTorch tutorials?
    :header: (beta) Implement High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
    :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components.
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
-   :link: beginner/scaled_dot_product_attention_tutorial.html
+   :link: intermediate/scaled_dot_product_attention_tutorial.html
    :tags: Model-Optimization,Attention,Transformer

 .. Parallel-and-Distributed-Training
@@ -916,7 +916,7 @@ Additional Resources
    intermediate/nvfuser_intro_tutorial
    intermediate/ax_multiobjective_nas_tutorial
    intermediate/torch_compile_tutorial
-   beginner/scaled_dot_product_attention_tutorial
+   intermediate/scaled_dot_product_attention_tutorial

 .. toctree::
    :maxdepth: 2

beginner_source/scaled_dot_product_attention_tutorial.py renamed to intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 17 additions & 12 deletions
@@ -12,7 +12,7 @@
 # In this tutorial, we want to highlight a new ``torch.nn.functional`` function
 # that can be helpful for implementing transformer architectures. The
 # function is named ``torch.nn.functional.scaled_dot_product_attention``.
-# For detailed description of the function, see the `PyTorch# documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
+# For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
 # This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
 #
 # Overview
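
Side note on the incorporation into ``torch.nn.MultiheadAttention`` mentioned above: existing module-level code can benefit without changes. A minimal sketch (the hyperparameters and shapes below are illustrative, not taken from the tutorial):

    import torch
    import torch.nn as nn

    # Illustrative self-attention module; embed_dim and num_heads are arbitrary here.
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
    x = torch.randn(2, 128, 512)  # (batch, seq_len, embed_dim)

    # need_weights=False is one of the conditions that lets the module use its fused fast path.
    out, _ = mha(x, x, x, need_weights=False)
    print(out.shape)  # torch.Size([2, 128, 512])
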
@@ -22,10 +22,7 @@
 # the definition found in the paper `Attention is all you
 # need <https://arxiv.org/abs/1706.03762>`__. While this function can be
 # written in PyTorch using existing functions, for GPU tensors this
-# function will implicitly dispatch to an optimized implementation. The
-# function is also highly modular and can be used to implement other
-# attention mechanisms such as
-# `Linformer <https://arxiv.org/abs/2006.04768>`__
+# function will implicitly dispatch to an optimized implementation.
 #
 # Fused implementations
 # ~~~~~~~~~~~~~~~~~~~~~~
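
For readers skimming the diff, a minimal sketch of calling the function directly (tensor shapes and the ``is_causal`` flag are illustrative, not taken from the tutorial):

    import torch
    import torch.nn.functional as F

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Illustrative shapes: (batch, num_heads, seq_len, head_dim)
    query = torch.randn(2, 8, 128, 64, device=device)
    key = torch.randn(2, 8, 128, 64, device=device)
    value = torch.randn(2, 8, 128, 64, device=device)

    # On GPU tensors this implicitly dispatches to a fused implementation when one is eligible.
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 128, 64])
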
@@ -234,7 +231,7 @@ def generate_rand_batch(
 # Currently the fastpaths don't support NestedTensor for training
 random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
 random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
-model.requires_grad_(False)
+model.eval()
 print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
 print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
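
Context for the ``requires_grad_(False)`` → ``eval()`` change (not part of the diff itself): ``requires_grad_(False)`` only freezes parameters for autograd, while ``eval()`` switches layers such as dropout to inference behavior; for pure benchmarking the two are often combined with ``torch.no_grad()``. A minimal sketch with a stand-in model:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in module, only to illustrate the difference.
    model = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.1))
    x = torch.randn(8, 64)

    model.eval()              # dropout becomes a no-op; batchnorm would use running stats
    with torch.no_grad():     # autograd records nothing during the timed calls
        y = model(x)

    model.requires_grad_(False)  # by itself this only stops gradient tracking of the parameters
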

@@ -256,14 +253,17 @@ def generate_rand_batch(
 
 
 compiled_model = torch.compile(model)
-# Lets warm it up once
+# Let's compile it
 compiled_model(x)
 print(
     f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
 
 
 ######################################################################
 #
+# The exact execution time depends on the machine; on mine the results were:
+# The non-compiled module runs in 166.616 microseconds
+# The compiled module runs in 166.726 microseconds
 # That is not what we were expecting. Let's dig a little deeper.
 # PyTorch comes with an amazing built-in profiler that you can use to
 # inspect the performance characteristics of your code.
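
As a side note on the benchmark above: the first call to a ``torch.compile``-wrapped module pays the compilation cost, so it is usually excluded from timing. A hedged sketch of that pattern, using ``torch.utils.benchmark`` in place of the tutorial's ``benchmark_torch_function_in_microseconds`` helper and a stand-in module:

    import torch
    import torch.utils.benchmark as benchmark

    model = torch.nn.Linear(512, 512)   # stand-in module
    x = torch.randn(32, 512)

    compiled_model = torch.compile(model)
    compiled_model(x)                   # warm-up: the first call triggers compilation

    def time_us(fn, inp):
        # Hypothetical helper: mean runtime in microseconds over 50 runs.
        t = benchmark.Timer(stmt="fn(inp)", globals={"fn": fn, "inp": inp})
        return t.timeit(50).mean * 1e6  # Timer reports seconds

    print(f"eager:    {time_us(model, x):.1f} us")
    print(f"compiled: {time_us(compiled_model, x):.1f} us")
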
@@ -278,14 +278,14 @@ def generate_rand_batch(
     with record_function(" Non-Compiled Causal Attention"):
         for _ in range(25):
             model(x)
-print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 
 with profile(activities=activities, record_shapes=False) as prof:
     with record_function("Compiled Causal Attention"):
         for _ in range(25):
             compiled_model(x)
-print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 # For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
 # prof.export_chrome_trace("compiled_causal_attention_trace.json").
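
For readers reproducing this outside the tutorial, the profiler boilerplate around these lines looks roughly like the following (the tutorial defines ``activities`` earlier; a CUDA device and a stand-in module are assumed here):

    import torch
    from torch.profiler import profile, record_function, ProfilerActivity

    model = torch.nn.Linear(512, 512, device="cuda")   # stand-in module
    x = torch.randn(32, 512, device="cuda")

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(activities=activities, record_shapes=False) as prof:
        with record_function("Causal Attention"):
            for _ in range(25):
                model(x)

    # Sorting by cuda_time_total surfaces the kernels that dominate GPU time.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    # prof.export_chrome_trace("attention_trace.json")  # optional: view in chrome://tracing
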
@@ -294,16 +294,21 @@ def generate_rand_batch(
 
 
 ######################################################################
+# The previous code snippet generates a report of the top 10 PyTorch functions
+# that consumed the most GPU execution time, for both the compiled and non-compiled module.
+# The analysis reveals that the majority of time spent on the GPU is concentrated
+# on the same set of functions for both modules.
 # The problem here is that ``torch.compile`` is very good at removing the
 # framework overhead associated with PyTorch. If your model is launching
 # large, efficient CUDA kernels, which in this case CausalSelfAttention
-# is, then the overhead of ``torch.compile`` can hurt performance.
+# is, then the overhead of PyTorch can be hidden.
 #
 # In reality, your module does not normally consist of a singular
 # CausalSelfAttention block. When experimenting with Andrej Karpathy’s
 # `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
-# the module took the time per train step from: ``902.01ms`` to
-# ``552.06ms``!
+# the module took the time per train step from ``6090.49ms`` to
+# ``3273.17ms``! This was done on commit ae3a8d5 of NanoGPT, training on
+# the Shakespeare dataset.
 #
 
 