From eefd7f0d42d189b352796fac5f53314eddf1ba2f Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Thu, 23 Feb 2023 23:01:54 +0000 Subject: [PATCH 1/2] add tutorial --- .../scaled_dot_product_attention_tutorial.py | 320 ++++++++++++++++++ index.rst | 10 +- 2 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 beginner_source/scaled_dot_product_attention_tutorial.py diff --git a/beginner_source/scaled_dot_product_attention_tutorial.py b/beginner_source/scaled_dot_product_attention_tutorial.py new file mode 100644 index 00000000000..a54b5ef9363 --- /dev/null +++ b/beginner_source/scaled_dot_product_attention_tutorial.py @@ -0,0 +1,320 @@ +""" +An overview of torch.nn.functional.scaled_dot_product_attention +=============================================================== + +""" + + +###################################################################### +# Summary +# ~~~~~~~~ +# +# In this tutorial we want to highlight a new ``torch.nn.functional`` function +# that can be helpful for implementing transformer architectures. The +# function is named ``torch.nn.functional.scaled_dot_product_attention``. +# There is some extensive documentation on the function in the `PyTorch +# documentation `__. +# This function has already been incorporated into torch.nn.MHA +# (Multi-Head Attention) and ``torch.nn.TransformerEncoderLayer``. +# +# Overview +# ~~~~~~~ +# At a high level this PyTorch function calculates the +# scaled dot product attention between query, key, and value according to +# the definition found in the paper `Attention is all you +# need `__. While this function can be +# written in PyTorch using existing functions, for GPU tensors this +# function will implicitly dispatch to an optimized implementation. The +# function is also highly modular and can be used to implement other +# attention mechanisms such as +# `Linformer `__ +# +# Fused implementations: +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# For CUDA tensor inputs the function will dispatch into one of three +# implementations: +# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `__ +# Attention with IO-Awareness `__ \* +# `Memory-Efficient +# Attention `__ \* A PyTorch +# implementation defined in C++ +# + +import torch +import torch.nn as nn +import torch.nn.functional as F +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Example Usage: +query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device) +F.scaled_dot_product_attention(query, key, value) + + +###################################################################### +# Explicit Dispatcher Control +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# While the function will implicitly dispatch to one of the three +# implementations, the user can also explicitly control the dispatch via +# the use of a context manager. This context manager allows users to +# explicitly disable certain implementations. If a user wants to ensure +# the function is indeed using the fasted implementation for their +# specific inputs the context manager can be used to sweep through +# measuring performance. 
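#
# Before sweeping through the fused kernels, it helps to recall what the function
# computes. The sketch below is our own simplified reference (the helper name
# ``naive_scaled_dot_product_attention`` is hypothetical, and it assumes an
# additive floating-point ``attn_mask``); it illustrates the math using existing
# PyTorch ops and is not the exact C++ implementation that the dispatcher selects:

import math

def naive_scaled_dot_product_attention(query, key, value, attn_mask=None,
                                        dropout_p=0.0, is_causal=False):
    # query, key, value have shape (..., seq_len, head_dim)
    L, S = query.size(-2), key.size(-2)
    # Scale the dot products by 1/sqrt(head_dim)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        # Prevent each position from attending to later positions
        causal_mask = torch.ones(L, S, device=query.device).tril().bool()
        attn_weight = attn_weight.masked_fill(~causal_mask, float("-inf"))
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = F.dropout(attn_weight, p=dropout_p)
    return attn_weight @ value

# With ``dropout_p=0.0`` this should agree with ``F.scaled_dot_product_attention``
# up to numerical tolerance, while the fused kernels additionally avoid
# materializing the full (seq_len x seq_len) attention matrix.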
+# + +# Lets define a helpful benchmarking function: +import torch.utils.benchmark as benchmark +def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + return t0.blocked_autorange().mean * 1e6 + +# Lets define the hyper-parameters of our input +batch_size = 32 +max_sequence_len = 1024 +num_heads = 32 +embed_dimension = 32 + +dtype = torch.float16 + +query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype) +key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype) +value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype) + +print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + +# Lets explore the speed of each of the 3 implementations +from torch.backends.cuda import sdp_kernel, SDPBackend + +# Helpful arg mapper +backend_map = { + SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False}, + SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False}, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, "enable_flash": False, "enable_mem_efficient": True} +} + +with sdp_kernel(**backend_map[SDPBackend.MATH]): + print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + + +with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): + print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + + +with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): + print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + + +###################################################################### +# Hardware dependence +# ~~~~~~~~~~~~~~~~~~~ +# +# Depending on what machine you ran the above cell on and what hardware is +# available your results might be different. +# - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all +# are running on CPU then the context manager will have no effect and all +# three run should return similar timings. - Depending on what Compute +# Capability your graphics card supports FlashAttention or memory +# efficient might have failed. +# + + +###################################################################### +# Causal Self Attention +# ~~~~~~~~~~~~~~~~~~~~~ +# +# Below is an example implementation of a multi-headed causal self +# attention block inspired by Andrej Karpathy’s +# `NanoGPT `__ repository. 
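#
# The block below requests causal masking through the ``is_causal=True`` flag
# instead of building a mask tensor. As a quick illustration (our own sanity
# check on small random tensors, not part of the NanoGPT-inspired module), the
# flag should be equivalent to passing an explicit lower-triangular boolean mask:

q = torch.randn(1, 1, 4, 8, device=device)
k = torch.randn(1, 1, 4, 8, device=device)
v = torch.randn(1, 1, 4, 8, device=device)
explicit_causal_mask = torch.ones(4, 4, device=device).tril().bool()
out_with_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=explicit_causal_mask)
print(f"is_causal matches an explicit causal mask: {torch.allclose(out_with_flag, out_with_mask, atol=1e-3)}")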
+# + +class CausalSelfAttention(nn.Module): + + def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0): + super().__init__() + assert embed_dimension % num_heads == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias) + # output projection + self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias) + # regularization + self.dropout = dropout + self.resid_dropout = nn.Dropout(dropout) + self.num_heads = num_heads + self.embed_dimension = embed_dimension + # Perform causal masking + self.is_causal = is_causal + + def forward(self, x): + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query_projected = self.c_attn(x) + + batch_size = query_projected.size(0) + embed_dim = query_projected.size(2) + head_dim = embed_dim // (self.num_heads * 3) + + query, key, value = query_projected.chunk(3, -1) + query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + + if self.training: + dropout = self.dropout + is_causal = self.is_causal + else: + dropout = 0.0 + is_causal = False + + y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal) + y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim) + + y = self.resid_dropout(self.c_proj(y)) + return y + + +num_heads = 8 +heads_per_dim = 64 +embed_dimension = num_heads * heads_per_dim +dtype = torch.float16 +model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval() +print(model) + + +###################################################################### +# NestedTensor and Dense tensor support +# ------------------------------------- +# + +import random +def generate_rand_batch( + batch_size, + max_sequence_len, + embed_dimension, + pad_percentage=None, + dtype=torch.float16, + device="cuda", +): + if not pad_percentage: + return ( + torch.randn( + batch_size, + max_sequence_len, + embed_dimension, + dtype=dtype, + device=device, + ), + None, + ) + # Random sequence lengths + seq_len_list = [ + int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01))) + for _ in range(batch_size) + ] + # Make random entry in the batch have max sequence length + seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len + return ( + torch.nested.nested_tensor( + [ + torch.randn(seq_len, embed_dimension, + dtype=dtype, device=device) + for seq_len in seq_len_list + ] + ), + seq_len_list, + ) + +# Currently the fastpaths don't support NestedTensor for training +random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device) +random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device) +model.requires_grad_(False) +print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds") +print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds") + + +###################################################################### +# Composable with 2.0 Features +# ============================ +# +# Scaled dot product attention is composable with 
torch.compile(). Lets +# try compiling the above CausalSelfAttention module +# + +batch_size = 32 +max_sequence_len = 256 +x = torch.rand(batch_size, max_sequence_len, + embed_dimension, device=device, dtype=dtype) +print( + f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds") + + +compiled_model = torch.compile(model) +# Lets warm it up once +compiled_model(x) +print( + f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds") + + +###################################################################### +# HMM.. +# ~~~~~ +# +# That is not what we were expecting. Let's dig a little deeper. +# PyTorch comes with an amazing built-in profiler that you can use to +# inspect the performance characteristics of your code. +# + +from torch.profiler import profile, record_function, ProfilerActivity +activities = [ProfilerActivity.CPU] +if device == 'cuda': + activities.append(ProfilerActivity.CUDA) + +with profile(activities=activities, record_shapes=False) as prof: + with record_function(" Non-Compilied Causal Attention"): + for _ in range(25): + model(x) +print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + +with profile(activities=activities, record_shapes=False) as prof: + with record_function("Compiled Causal Attention"): + for _ in range(25): + compiled_model(x) +print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + +# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results +# prof.export_chrome_trace("compiled_causal_attention_trace.json") + + + + +###################################################################### +# The problem here is that ``torch.compile`` is very good at removing the +# framework overhead associated with PyTorch. If your model is launching +# large, efficient CUDA kernels, which in this case CausaulSelfAttention +# is, then the overhead of ``torch.compile`` can hurt performance. +# +# In reality, your module does not normally consist of a singular +# CausalSelfAttention block. When experimenting with Andrej Karpathy’s +# `NanoGPT `__ repository, compiling +# the module took the time per train step from: ``902.01ms`` to +# ``552.06ms``! +# + + +###################################################################### +# Conclusion +# ========== +# +# In this tutorial, we have demonstrated the basic usage of +# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how +# the ``sdp_kernel`` context manager can be used to assert a certain +# implementation is used on GPU. As well, we built a simple +# CausalSelfAttention module that works with NestedTensor and is torch +# compilable. In the process we have shown how to the profiling tools can +# be used to explore the performance characteristics of a user defined +# module. +# \ No newline at end of file diff --git a/index.rst b/index.rst index 0e2f1eaeaa6..f8b03e6ce1c 100644 --- a/index.rst +++ b/index.rst @@ -502,7 +502,7 @@ What's new in PyTorch tutorials? :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png :link: intermediate/torchserve_with_ipex_2 :tags: Model-Optimization,Production - + .. customcarditem:: :header: Introduction to nvFuser :card_description: An introduction to nvFuser @@ -524,6 +524,13 @@ What's new in PyTorch tutorials? :link: intermediate/torch_compile_tutorial.html :tags: Model-Optimization +.. 
customcarditem:: + :header: (beta) An overview of torch.nn.functional.scaled_dot_product_attention + :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components. + :image: _static/img/thumbnails/cropped/pytorch-logo.png + :link: beginner/scaled_dot_product_attention_tutorial.html + :tags: Model-Optimization,Attention,Transformer + .. Parallel-and-Distributed-Training @@ -909,6 +916,7 @@ Additional Resources intermediate/nvfuser_intro_tutorial intermediate/ax_multiobjective_nas_tutorial intermediate/torch_compile_tutorial + beginner/scaled_dot_product_attention_tutorial .. toctree:: :maxdepth: 2 From 65cebeb382cad81c449d27764ee984a31be5b166 Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Mon, 27 Feb 2023 17:50:02 +0000 Subject: [PATCH 2/2] fixes --- index.rst | 6 +- .../scaled_dot_product_attention_tutorial.py | 161 ++++++++++-------- 2 files changed, 92 insertions(+), 75 deletions(-) rename {beginner_source => intermediate_source}/scaled_dot_product_attention_tutorial.py (68%) diff --git a/index.rst b/index.rst index f8b03e6ce1c..a6e29db0b75 100644 --- a/index.rst +++ b/index.rst @@ -525,10 +525,10 @@ What's new in PyTorch tutorials? :tags: Model-Optimization .. customcarditem:: - :header: (beta) An overview of torch.nn.functional.scaled_dot_product_attention + :header: (beta) Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components. :image: _static/img/thumbnails/cropped/pytorch-logo.png - :link: beginner/scaled_dot_product_attention_tutorial.html + :link: intermediate/scaled_dot_product_attention_tutorial.html :tags: Model-Optimization,Attention,Transformer .. Parallel-and-Distributed-Training @@ -916,7 +916,7 @@ Additional Resources intermediate/nvfuser_intro_tutorial intermediate/ax_multiobjective_nas_tutorial intermediate/torch_compile_tutorial - beginner/scaled_dot_product_attention_tutorial + intermediate/scaled_dot_product_attention_tutorial .. toctree:: :maxdepth: 2 diff --git a/beginner_source/scaled_dot_product_attention_tutorial.py b/intermediate_source/scaled_dot_product_attention_tutorial.py similarity index 68% rename from beginner_source/scaled_dot_product_attention_tutorial.py rename to intermediate_source/scaled_dot_product_attention_tutorial.py index a54b5ef9363..c5f84308f42 100644 --- a/beginner_source/scaled_dot_product_attention_tutorial.py +++ b/intermediate_source/scaled_dot_product_attention_tutorial.py @@ -1,6 +1,6 @@ """ -An overview of torch.nn.functional.scaled_dot_product_attention -=============================================================== +Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION +================================================================================ """ @@ -8,38 +8,32 @@ ###################################################################### # Summary # ~~~~~~~~ -# -# In this tutorial we want to highlight a new ``torch.nn.functional`` function +# +# In this tutorial, we want to highlight a new ``torch.nn.functional`` function # that can be helpful for implementing transformer architectures. The # function is named ``torch.nn.functional.scaled_dot_product_attention``. -# There is some extensive documentation on the function in the `PyTorch -# documentation `__. 
-# This function has already been incorporated into torch.nn.MHA -# (Multi-Head Attention) and ``torch.nn.TransformerEncoderLayer``. -# +# For detailed description of the function, see the `PyTorch documentation `__. +# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``. +# # Overview -# ~~~~~~~ -# At a high level this PyTorch function calculates the -# scaled dot product attention between query, key, and value according to +# ~~~~~~~~~ +# At a high level, this PyTorch function calculates the +# scaled dot product attention (SDPA) between query, key, and value according to # the definition found in the paper `Attention is all you -# need `__. While this function can be -# written in PyTorch using existing functions, for GPU tensors this -# function will implicitly dispatch to an optimized implementation. The -# function is also highly modular and can be used to implement other -# attention mechanisms such as -# `Linformer `__ -# -# Fused implementations: +# need `__. While this function can +# be written in PyTorch using existing functions, a fused implementation can provide +# large performance benefits over a naive implementation. +# +# Fused implementations # ~~~~~~~~~~~~~~~~~~~~~~ -# -# For CUDA tensor inputs the function will dispatch into one of three -# implementations: +# +# For CUDA tensor inputs, the function will dispatch into one of the following +# implementations: +# # * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `__ -# Attention with IO-Awareness `__ \* -# `Memory-Efficient -# Attention `__ \* A PyTorch -# implementation defined in C++ -# +# * `Memory-Efficient Attention `__ +# * A PyTorch implementation defined in C++ +# import torch import torch.nn as nn @@ -54,15 +48,15 @@ ###################################################################### # Explicit Dispatcher Control # ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# +# # While the function will implicitly dispatch to one of the three # implementations, the user can also explicitly control the dispatch via # the use of a context manager. This context manager allows users to # explicitly disable certain implementations. If a user wants to ensure -# the function is indeed using the fasted implementation for their -# specific inputs the context manager can be used to sweep through +# the function is indeed using the fastest implementation for their +# specific inputs, the context manager can be used to sweep through # measuring performance. -# +# # Lets define a helpful benchmarking function: import torch.utils.benchmark as benchmark @@ -102,35 +96,38 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): - print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") - + try: + print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + except RuntimeError: + print("FlashAttention is not supported. 
See warnings for reasons.") with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): - print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + try: + print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds") + except RuntimeError: + print("EfficientAttention is not supported. See warnings for reasons.") ###################################################################### # Hardware dependence # ~~~~~~~~~~~~~~~~~~~ -# +# # Depending on what machine you ran the above cell on and what hardware is -# available your results might be different. -# - If you don’t have a GPU and are running on CPU then the context manager will have no effect and all -# are running on CPU then the context manager will have no effect and all -# three run should return similar timings. - Depending on what Compute -# Capability your graphics card supports FlashAttention or memory -# efficient might have failed. -# +# available, your results might be different. +# - If you don’t have a GPU and are running on CPU then the context manager +# will have no effect and all three runs should return similar timings. +# - Depending on what compute capability your graphics card supports +# flash attention or memory efficient might have failed. ###################################################################### # Causal Self Attention # ~~~~~~~~~~~~~~~~~~~~~ -# +# # Below is an example implementation of a multi-headed causal self # attention block inspired by Andrej Karpathy’s # `NanoGPT `__ repository. -# +# class CausalSelfAttention(nn.Module): @@ -187,7 +184,11 @@ def forward(self, x): ###################################################################### # NestedTensor and Dense tensor support # ------------------------------------- -# +# +# SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences +# without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see +# `torch.nested `__ and `NestedTensors Tutorial `__. +# import random def generate_rand_batch( @@ -227,21 +228,31 @@ def generate_rand_batch( seq_len_list, ) -# Currently the fastpaths don't support NestedTensor for training random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device) random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device) -model.requires_grad_(False) -print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds") -print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds") + +# Currently the fused implementations don't support NestedTensor for training +model.eval() + +with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): + try: + print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds") + print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds") + except RuntimeError: + print("FlashAttention is not supported. 
See warnings for reasons.") ###################################################################### -# Composable with 2.0 Features +# Using SDPA with torch.compile # ============================ -# -# Scaled dot product attention is composable with torch.compile(). Lets -# try compiling the above CausalSelfAttention module -# +# +# With the release of PyTorch 2.0, a new feature called +# ``torch.compile()`` has been introduced, which can provide +# significant performance improvements over eager mode. +# Scaled dot product attention is fully composable with ``torch.compile()``. +# To demonstrate this, let's compile the CausalSelfAttention module using +# ``torch.compile()`` and observe the resulting performance improvements. +# batch_size = 32 max_sequence_len = 256 @@ -252,20 +263,21 @@ def generate_rand_batch( compiled_model = torch.compile(model) -# Lets warm it up once +# Let's compile it compiled_model(x) print( f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds") ###################################################################### -# HMM.. -# ~~~~~ -# -# That is not what we were expecting. Let's dig a little deeper. +# +# The exact execution time is dependent on machine, however the results for mine: +# The non compiled module runs in 166.616 microseconds +# The compiled module runs in 166.726 microseconds +# That is not what we were expecting. Let's dig a little deeper. # PyTorch comes with an amazing built-in profiler that you can use to # inspect the performance characteristics of your code. -# +# from torch.profiler import profile, record_function, ProfilerActivity activities = [ProfilerActivity.CPU] @@ -276,39 +288,44 @@ def generate_rand_batch( with record_function(" Non-Compilied Causal Attention"): for _ in range(25): model(x) -print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) with profile(activities=activities, record_shapes=False) as prof: with record_function("Compiled Causal Attention"): for _ in range(25): compiled_model(x) -print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) # For even more insights, you can export the trace and use ``chrome://tracing`` to view the results -# prof.export_chrome_trace("compiled_causal_attention_trace.json") +# prof.export_chrome_trace("compiled_causal_attention_trace.json"). ###################################################################### -# The problem here is that ``torch.compile`` is very good at removing the +# The previous code snippet generates a report of the top 10 PyTorch functions +# that consumed the most GPU execution time, for both the compiled and non-compiled module. +# The analysis reveals that the majority of time spent on the GPU is concentrated +# on the same set of functions for both modules. +# The reason for this here is that ``torch.compile`` is very good at removing the # framework overhead associated with PyTorch. If your model is launching # large, efficient CUDA kernels, which in this case CausaulSelfAttention -# is, then the overhead of ``torch.compile`` can hurt performance. -# +# is, then the overhead of PyTorch can be hidden. +# # In reality, your module does not normally consist of a singular # CausalSelfAttention block. 
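#
# One way to see this effect without a full training run is to stack several of
# the attention blocks defined above and compile the whole stack, as sketched
# below. The ``ToyBlockStack`` is a hypothetical construction of ours (it is not
# the NanoGPT model, which also contains MLP blocks, embeddings, and a language
# modeling head):

class ToyBlockStack(nn.Module):
    def __init__(self, num_layers: int = 12):
        super().__init__()
        self.layers = nn.ModuleList(
            CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension,
                                bias=False, is_causal=True, dropout=0.1)
            for _ in range(num_layers)
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

toy_model = ToyBlockStack().to(device=device, dtype=dtype).eval()
compiled_toy_model = torch.compile(toy_model)
compiled_toy_model(x)  # warm up once so compilation time is not measured
print(f"Eager toy stack runs in    {benchmark_torch_function_in_microseconds(toy_model, x):.3f} microseconds")
print(f"Compiled toy stack runs in {benchmark_torch_function_in_microseconds(compiled_toy_model, x):.3f} microseconds")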
# When experimenting with Andrej Karpathy’s
# `NanoGPT `__ repository, compiling
-# the module took the time per train step from: ``902.01ms`` to
-# ``552.06ms``!
-#
+# the module took the time per train step from ``6090.49ms`` to
+# ``3273.17ms``! This was measured at commit ``ae3a8d5`` of NanoGPT, training on
+# the Shakespeare dataset.
+#


######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a particular
# implementation is used on GPU. In addition, we built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# ``torch.compile`` compatible. Along the way, we have shown how the profiling
# tools can be used to explore the performance characteristics of a
# user-defined module.
#
\ No newline at end of file