diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index aa479828d02..6e82d054b4e 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -33,7 +33,7 @@ }, "recipes_source/torch_export_aoti_python.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" - }, + }, "advanced_source/pendulum.py": { "needs": "linux.g5.4xlarge.nvidia.gpu", "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run." @@ -58,6 +58,9 @@ "intermediate_source/scaled_dot_product_attention_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, + "intermediate_source/transformer_building_blocks.py": { + "needs": "linux.g5.4xlarge.nvidia.gpu" + }, "recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py index 2aaa5d6ef71..c3bf4c5534b 100644 --- a/.jenkins/validate_tutorials_built.py +++ b/.jenkins/validate_tutorials_built.py @@ -25,6 +25,7 @@ "intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py "intermediate_source/fx_conv_bn_fuser", "intermediate_source/_torch_export_nightly_tutorial", # does not work on release + "intermediate_source/transformer_building_blocks", # does not work on release "advanced_source/super_resolution_with_onnxruntime", "advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker "prototype_source/fx_graph_mode_ptq_dynamic", diff --git a/en-wordlist.txt b/en-wordlist.txt index 8a118b941b2..2ccab08b094 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -1,5 +1,6 @@ ACL ADI +ALiBi AOT AOTInductor APIs @@ -79,6 +80,7 @@ FX FX's FairSeq Fastpath +FFN FloydHub FloydHub's Frobenius @@ -127,6 +129,7 @@ Kihyuk Kiuk Kubernetes Kuei +KV LRSchedulers LSTM LSTMs @@ -162,6 +165,7 @@ NLP NTK NUMA NaN +NaNs NanoGPT Netron NeurIPS @@ -231,6 +235,7 @@ Sigmoid SoTA Sohn Spacy +SwiGLU TCP THP TIAToolbox @@ -276,6 +281,7 @@ Xcode Xeon Yidong YouTube +Zipf accelerometer accuracies activations @@ -305,6 +311,7 @@ bbAP benchmarked benchmarking bitwise +bool boolean breakpoint broadcasted @@ -333,6 +340,7 @@ csv cuDNN cuda customizable +customizations datafile dataflow dataframe @@ -377,6 +385,7 @@ fbgemm feedforward finetune finetuning +FlexAttention fp frontend functionalized @@ -431,6 +440,7 @@ mAP macos manualSeed matmul +matmuls matplotlib memcpy memset @@ -446,6 +456,7 @@ modularized mpp mucosa multihead +MultiheadAttention multimodal multimodality multinode @@ -456,7 +467,11 @@ multithreading namespace natively ndarrays +nheads nightlies +NJT +NJTs +NJT's num numericalize numpy @@ -532,6 +547,7 @@ runtime runtime runtimes scalable +SDPA sharded softmax sparsified @@ -591,12 +607,14 @@ tradeoff tradeoffs triton uint +UX umap uncomment uncommented underflowing unfused unimodal +unigram unnormalized unoptimized unparametrized @@ -618,6 +636,7 @@ warmstarted warmstarting warmup webp +wikitext wsi wsis Meta's diff --git a/index.rst b/index.rst index 86592b22411..385e589de3b 100644 --- a/index.rst +++ b/index.rst @@ -664,6 +664,14 @@ Welcome to PyTorch Tutorials :link: beginner/knowledge_distillation_tutorial.html :tags: Model-Optimization,Image/Video + +.. customcarditem:: + :header: Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile() + :card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch. 
+ :image: _static/img/thumbnails/cropped/pytorch-logo.png + :link: intermediate/transformer_building_blocks.html + :tags: Transformer + .. Parallel-and-Distributed-Training diff --git a/intermediate_source/transformer_building_blocks.py b/intermediate_source/transformer_building_blocks.py new file mode 100644 index 00000000000..932be472e89 --- /dev/null +++ b/intermediate_source/transformer_building_blocks.py @@ -0,0 +1,781 @@ +""" +Accelerating PyTorch Transformers by replacing ``nn.Transformer`` with Nested Tensors and ``torch.compile()`` +============================================================================================================= +**Author:** `Mikayla Gawarecki `_ + +.. note:: + This tutorial currently requires you to use the PyTorch nightly build. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites + + * Learn about the low-level building blocks PyTorch provides to build custom transformer layers ( + nested tensors, ``scaled_dot_product_attention``, ``torch.compile()``, and ``FlexAttention``) + * Discover how the above improve memory usage and performance using MultiHeadAttention as an example + * Explore advanced customizations using the aforementioned building blocks + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + :class-card: card-prerequisites + + * PyTorch v.2.6.0 or later + + +Over the past few years, the PyTorch team has developed various lower level +features that, when composed, can create a variety of transformer variants. These +include: + +* Nested Tensors with the ``torch.jagged`` layout (AKA NJTs) +* ``scaled_dot_product_attention`` +* ``torch.compile()`` +* ``FlexAttention`` + +This tutorial will give a brief overview of the above technologies and +demonstrate how they can be composed to yield flexible and performant transformer \ +layers with improved user experience. + +One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers. +In particular, it includes ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``, +``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family +of layers was initially implemented following the `Attention is All +You Need `_ paper. The components discussed in +this tutorial provide improved user experience, flexibility and performance over +the existing ``nn`` layers. + + +Is this tutorial for me? +======================== + +If you are wondering about what building blocks the ``torch`` library provides +for writing your own transformer layers and best practices, you are in the +right place. Please keep reading! + +If you are looking for an out-of-the-box implementation of a popular transformer +architecture, note that there are many open-source libraries that provide them, +including: + +* `HuggingFace transformers `_ +* `xformers `_ +* `torchtune `_ + +If you are only interested in performant attention score modifications, please +check out the `FlexAttention blog `_ that +contains a `gym of masks `_. + +""" + +################################################################################ +# Introducing the Building Blocks +# =============================== +# First, we will briefly introduce the four technologies mentioned in the introduction +# +# * `torch.nested `_ +# +# Nested tensors generalize the shape of regular dense tensors, allowing for +# representation of ragged-sized data with the same tensor UX. 
In the context of
# transformers, we can think of nested tensors as a tool for representing variable
# sequence lengths. They eliminate the need for the bug-prone practices of explicit
# padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
#
# * `scaled_dot_product_attention `_
#
# ``scaled_dot_product_attention`` is a primitive for
# :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
# implementations of the operator or a fallback implementation. It works out of
# the box in eager mode (i.e. the default mode of using PyTorch where operations
# are executed on the fly as they are encountered) and also integrates seamlessly
# with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
# natively.
#
# * `torch.compile() `_
#
# ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
# capture a graph of PyTorch code and perform various optimizations on it, such as
# fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
# and ``scaled_dot_product_attention`` work seamlessly with compile. In the
# context of transformers, the value-add of using compile with nested tensors
# and SDPA is that compile can remove the framework overhead one sees in eager mode
# and fuse sequences of ops in transformers together, such as projection and
# activation.
#
# * `FlexAttention `_
#
# ``FlexAttention`` is a primitive that allows users to modify attention scores
# prior to the softmax operation. It generalizes the additive ``B`` term above
# for ``scaled_dot_product_attention``, allowing for arbitrary computation. It
# requires compile to achieve good performance.
#
# The above building blocks are "All You Need" (as of October 2024)
# ==================================================================
#
# The main premise in this section is that most transformer variations are
# GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
# Blocks and Feed Forward networks. If we were to try to classify the differences
# in this space, we might land on something like:
#
# 1. Layer type (activation functions such as ``SwiGLU``, normalization functions
#    such as ``RMSNorm``, positional encodings such as Sinusoidal or Rotary).
# 2. Layer ordering, such as where to apply norms and positional encoding.
# 3. Modifications to the attention score, such as ``ALiBi``, Relative Positional Bias, and so on.
#
# In a pre-compiler environment, you might write a custom transformer and notice
# that it functions correctly but is slow. To address this, you might develop a
# custom fused kernel for the specific series of operations. In a compiler environment,
# you can simply write the custom transformer, compile it, and benefit from the improved performance.


###############################################################################
# MultiheadAttention
# ------------------
# Remember that MultiheadAttention takes in a query, key, and value, and consists
# of an input projection, a ``scaled_dot_product_attention`` operator and an
# output projection. The main takeaway we want to demonstrate here is the
# improvement yielded when we replace padded/masked inputs with nested tensors.
# The improvements are threefold:
#
# * **User Experience**
#   Remember that ``nn.MultiheadAttention`` requires ``query``, ``key``, and
#   ``value`` to be dense ``torch.Tensors``.
It also provides a +# ``key_padding_mask`` that is used to mask out padding tokens in the ``key`` +# that arise due to different sequence lengths within a batch. Since there is +# no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice +# the outputs appropriately to account for query sequence lengths. ``NestedTensor`` +# cleanly removes the need for this sort of error-prone padding masks. +# +# * **Memory** +# Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]`` +# padding mask (where ``B`` is batch size, ``S`` is max sequence length in the +# batch and ``D`` is embedding size), nested tensors allow you to cleanly +# represent the batch of varying sequence lengths. As a result, the input and +# intermediate activations will use less memory. +# +# * **Performance** +# Since padding is not materialized and unnecessary computation on padding is +# skipped, performance and memory usage improve. +# +# We'll demonstrate the above by building upon the ``MultiheadAttention`` layer in the +# `Nested Tensor tutorial `_ +# and comparing it to the ``nn.MultiheadAttention`` layer. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MultiHeadAttention(nn.Module): + """ + Computes multi-head attention. Supports nested or padded tensors. + + Args: + E_q (int): Size of embedding dim for query + E_k (int): Size of embedding dim for key + E_v (int): Size of embedding dim for value + E_total (int): Total embedding dim of combined heads post input projection. Each head + has dim E_total // nheads + nheads (int): Number of heads + dropout (float, optional): Dropout probability. Default: 0.0 + bias (bool, optional): Whether to add bias to input projection. Default: True + """ + def __init__( + self, + E_q: int, + E_k: int, + E_v: int, + E_total: int, + nheads: int, + dropout: float = 0.0, + bias=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.nheads = nheads + self.dropout = dropout + self._qkv_same_embed_dim = E_q == E_k and E_q == E_v + if self._qkv_same_embed_dim: + self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) + else: + self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) + self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs) + self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs) + E_out = E_q + self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs) + assert E_total % nheads == 0, "Embedding dim is not divisible by nheads" + self.E_head = E_total // nheads + self.bias = bias + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask=None, + is_causal=False) -> torch.Tensor: + """ + Forward pass; runs the following process: + 1. Apply input projection + 2. Split heads and prepare for SDPA + 3. Run SDPA + 4. Apply output projection + + Args: + query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``) + key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``) + value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``) + attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None + is_causal (bool, optional): Whether to apply causal mask. Default: False + + Returns: + attn_output (torch.Tensor): output of shape (N, L_t, E_q) + """ + # Step 1. 
Apply input projection + if self._qkv_same_embed_dim: + if query is key and key is value: + result = self.packed_proj(query) + query, key, value = torch.chunk(result, 3, dim=-1) + else: + q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0) + if self.bias: + q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0) + else: + q_bias, k_bias, v_bias = None, None, None + query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias) + + else: + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + # Step 2. Split heads and prepare for SDPA + # reshape query, key, value to separate by head + # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) + query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) + # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) + key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) + # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) + value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) + + # Step 3. Run SDPA + # (N, nheads, L_t, E_head) + attn_output = F.scaled_dot_product_attention( + query, key, value, dropout_p=self.dropout, is_causal=is_causal) + # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total) + attn_output = attn_output.transpose(1, 2).flatten(-2) + + # Step 4. Apply output projection + # (N, L_t, E_total) -> (N, L_t, E_out) + attn_output = self.out_proj(attn_output) + + return attn_output + + +############################################################################### +# Utilities +# ~~~~~~~~~ +# In this section, we include a utility to generate semi-realistic data using +# ``Zipf`` distribution for sentence lengths. This is used to generate the nested +# query, key, and value tensors. We also include a benchmark utility. + + +import numpy as np + +def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor: + # generate fake corpus by unigram Zipf distribution + # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858 + sentence_lengths = np.empty(batch_size, dtype=int) + for ibatch in range(batch_size): + sentence_lengths[ibatch] = 1 + word = np.random.zipf(alpha) + while word != 3 and word != 386 and word != 858: + sentence_lengths[ibatch] += 1 + word = np.random.zipf(alpha) + return torch.tensor(sentence_lengths) + +# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths +# in the form of nested tensors with the jagged layout. +def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False): + # generate semi-realistic data using Zipf distribution for sentence lengths + sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N) + + # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged + # dimension and works with torch.compile. The batch items each have shape (B, S*, D) + # where B = batch size, S* = ragged sequence length, and D = embedding dimension. 
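    # (Each list element passed to ``torch.nested.nested_tensor`` below is a dense
    # ``(seq_len, E)`` tensor; with ``layout=torch.jagged`` they are combined into a
    # single NJT without materializing any padding.)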
+ if query_seq_len_1: + query = torch.nested.nested_tensor([ + torch.randn(1, E_q, dtype=dtype, device=device) + for l in sentence_lengths + ], layout=torch.jagged) + else: + query = torch.nested.nested_tensor([ + torch.randn(l.item(), E_q, dtype=dtype, device=device) + for l in sentence_lengths + ], layout=torch.jagged) + + key = torch.nested.nested_tensor([ + torch.randn(s.item(), E_k, dtype=dtype, device=device) + for s in sentence_lengths + ], layout=torch.jagged) + + value = torch.nested.nested_tensor([ + torch.randn(s.item(), E_v, dtype=dtype, device=device) + for s in sentence_lengths + ], layout=torch.jagged) + + return query, key, value, sentence_lengths + +import timeit +import math + +def benchmark(func, *args, **kwargs): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + begin = timeit.default_timer() + output = func(*args, **kwargs) + torch.cuda.synchronize() + end = timeit.default_timer() + return output, (end - begin), torch.cuda.max_memory_allocated() + +############################################################################## +# We will now demonstrate the performance improvements of using nested tensors +# in the ``MultiheadAttention`` layer + compile for self attention. We compare this against +# the traditional ``nn.MultiheadAttention`` + compile with padding and masking. + +N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512 +E_out = E_q +d_model = E_q +nheads = 8 +dropout = 0.0 +bias = True +device='cuda' +torch.manual_seed(6) +query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device) +S = sentence_lengths.max().item() +print(f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}") +padded_query, padded_key, padded_value = ( + t.to_padded_tensor(0.0) for t in (query, key, value) +) + +torch.manual_seed(6) +mha_layer = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device='cuda') +torch.manual_seed(6) +vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda') + +# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :( +mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach()) +mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach()) +mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach()) +mha_layer.packed_proj.bias = nn.Parameter(vanilla_mha_layer.in_proj_bias.clone().detach()) + +new_mha_layer = torch.compile(mha_layer) +# warmup compile +nested_result_warmup = new_mha_layer(query, query, query, is_causal=True) + +# benchmark +nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True) +padded_nested_result = nested_result.to_padded_tensor(0.0) + +# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask`` +# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal`` +src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0] +attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf')) +for i, s in enumerate(sentence_lengths): + attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s) +attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N*nheads, S, S) + +vanilla_mha_layer = torch.compile(vanilla_mha_layer) +# warmup compile +warmup_vanilla_result = 
vanilla_mha_layer(padded_query, + padded_query, + padded_query, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + need_weights=False, + is_causal=True) + +# benchmark +(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer, + padded_query, + padded_query, + padded_query, + key_padding_mask=src_key_padding_mask, + need_weights=False, + attn_mask=attn_mask, + is_causal=True) + +print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB") +print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB") +print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item()) +print(f"Nested speedup: {(padded_time/nested_time):.2f}") +print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB") + +###################################################################################### +# For reference, here are some sample outputs on A100: +# +# .. code:: +# +# padded_time=0.03454, padded_peak_memory=4.14 GB +# nested_time=0.00612, nested_peak_memory=0.76 GB +# Max difference between vanilla and nested result 0.0 +# Nested speedup: 5.65 +# Nested peak memory reduction 3.39 GB +# +# We can also see the same for backward pass + +for i, entry_length in enumerate(sentence_lengths): + # padding-specific step: remove output projection bias from padded entries for fair comparison + padded_result[i, entry_length:, :] = 0.0 + +_, padded_bw_time, padded_bw_peak_mem = benchmark(lambda : padded_result.sum().backward()) +_, nested_bw_time, nested_bw_peak_mem = benchmark(lambda : padded_nested_result.sum().backward()) + +print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB") +print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB") +print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}") +print(f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB") + +print("Difference in out_proj.weight.grad", (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad).abs().max().item()) +print("Difference in packed_proj.weight.grad", (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad).abs().max().item()) +print("Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item()) +print("Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item()) + +################################################################################## +# Sample outputs on A100: +# +# .. code:: +# +# padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB +# nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB +# Nested backward speedup: 144.13 +# Nested backward peak memory reduction 1.86 GB +# Difference in out_proj.weight.grad 0.000244140625 +# Difference in packed_proj.weight.grad 0.001556396484375 +# Difference in out_proj.bias.grad 0.0 +# Difference in packed_proj.bias.grad 0.001953125 +# + +################################################################################## +# GPT-style layer +# --------------- +# A basic GPT-style transformer layer consists of a causal self-attention layer +# followed by a feed-forward network (FFN) with skip connections. 
Implementing +# this is fairly straightforward using the ``MultiheadAttention`` layer above and +# gives equivalent results to an ``nn.TransformerEncoderLayer`` with +# ``is_causal=True``. +# +# We demonstrate examples of implementing the rest of the ``nn`` layers +# `here `_ +# but omit that from this tutorial for brevity. + + +############################################################################### +# Going one step further +# ---------------------- +# So far, we have demonstrated how to implement a performant ``MultiheadAttention`` +# layer that follows the traditional ``nn.MultiheadAttention``. Going back to our +# classification of modifications to the transformer architecture, remember that we +# classified the modifications into layer type, layer ordering, and modifications +# to the attention score. We trust that changing layer type and layer ordering +# (such as swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward. +# +# In this section, we will discuss various functionalities using the +# aforementioned building blocks, including the following: +# +# * Cross Attention +# * Fully masked rows no longer cause NaNs +# * Modifying attention score: ALiBi with FlexAttention and NJT +# * Packed Projection + +############################################################################### +# Cross Attention +# --------------- +# Cross attention is a form of attention where the query and key/value tensors +# are from different sequences. +# +# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes +# from the decoder and the key/value come from the encoder. +# +# The above MultiheadAttention layer nicely generalizes to this case with nested +# tensors for both query and key/value. + +query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device) +_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device) + +print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}") +print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}") +out = new_mha_layer(query, key, value, is_causal=False) + +######################################################################################## +# As above, we can compare this against the vanilla compiled ``nn.MultiheadAttention``. 
torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(padded_query,
                                          padded_key,
                                          padded_value,
                                          key_padding_mask=key_padding_mask,
                                          need_weights=False,
                                          is_causal=False)

nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, key, value, is_causal=False)
(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer,
                                                                padded_query,
                                                                padded_key,
                                                                padded_value,
                                                                key_padding_mask=key_padding_mask,
                                                                need_weights=False,
                                                                is_causal=False)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

##################################################################################
# Sample outputs on A100:
#
# .. code::
#
#    Max difference between vanilla and nested result 0.0
#    Nested speedup: 4.01
#    Nested peak memory reduction 1.40 GB
#

################################################################################
# Fully masked rows no longer cause NaNs
# --------------------------------------
#
# There has been a long-standing issue with ``nn.MultiheadAttention`` and
# ``scaled_dot_product_attention`` where if a row was fully masked out, the output
# of the attention layer would be NaN. See `issue `_.
# This is because the softmax over an empty set is undefined.
#
# Thanks to `this PR `_
# this is no longer the case. Instead, the output corresponding to fully masked rows
# in ``scaled_dot_product_attention`` is now zero. For cases where ``nn.MHA`` does
# not employ the "fast-path", this also applies.
#
# Using a custom MHA layer with NJTs is strongly recommended over the
# existing "fast-path" in ``nn.MultiheadAttention``, as NJT's ability to model raggedness
# appropriately makes it possible to properly express empty sequences.


################################################################################
# FlexAttention + NJT
# -------------------
# NJT also composes with ``FlexAttention``. This is a generalization of
# ``scaled_dot_product_attention`` that allows for arbitrary, user-defined
# modifications to the attention score. The example below takes the ``alibi_mod``
# that implements `ALiBi `_ from
# `attention gym `_ and uses it
# with nested input tensors.
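###############################################################################
# As a quick reminder of the ``score_mod`` contract before diving into ALiBi: a
# ``score_mod`` is a callable taking ``(score, b, h, q_idx, kv_idx)`` and returning
# a modified score. The toy example below is purely illustrative (it is not part of
# the attention-gym code); it could be passed to ``flex_attention`` via
# ``score_mod=toy_distance_penalty`` in the same way as ``alibi_score_mod`` below.

def toy_distance_penalty(score, b, h, q_idx, kv_idx):
    # Downweight attention between distant query/key positions (illustrative only)
    return score - 0.1 * torch.abs(q_idx - kv_idx)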
+ +from torch.nn.attention.flex_attention import flex_attention + +def generate_alibi_bias(H: int): + """Returns an alibi bias score_mod given the number of heads H + Args: + H: number of heads + Returns: + alibi_bias: alibi bias score_mod + """ + def alibi_mod(score, b, h, q_idx, kv_idx): + scale = torch.exp2(-((h + 1) * 8.0 / H)) + bias = (q_idx - kv_idx) * scale + return score + bias + return alibi_mod + +query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) +n_heads, D = 8, E_q // 8 +alibi_score_mod = generate_alibi_bias(n_heads) +query = ( + query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +) +key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +value = ( + value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +) +out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod) + +############################################################################### +# In addition, one can also use the ``block_mask`` utility of ``FlexAttention`` +# with NJTs via the ``create_nested_block_mask`` function. This is useful for +# taking advantage of the sparsity of the mask to speed up the attention computation. +# In particular, the function creates a sparse block mask for a "stacked sequence" of all +# the variable length sequences in the NJT combined into one, while properly masking out +# inter-sequence attention. In the following example, we show how to create a +# causal block mask using this utility. + +from torch.nn.attention.flex_attention import create_nested_block_mask + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + +query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) +block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) +query = ( + query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +) +key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +value = ( + value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() +) +out_flex = flex_attention(query, key, value, block_mask=block_mask) + +############################################################################### +# Packed Projection +# ----------------- +# +# Packed projection is a technique that makes use of the fact that when the input +# for projection (matrix multiplications) are the same (self-attention), we can pack the projection +# weights and biases into single tensors. It is especially useful when the individual +# projections are memory bound rather than compute bound. There are +# two examples that we will demonstrate here: +# +# * Input projection for MultiheadAttention +# * SwiGLU activation in feed-forward network of Transformer Layer +# +# Input projection for MultiheadAttention +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# When doing self-attention, the ``query``, ``key``, and ``value`` +# are the same tensor. Each of these tensors is projected with a +# ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer, +# which is what we do in the MultiheadAttention layer above. 
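###############################################################################
# As a small sanity check before benchmarking (a hypothetical snippet reusing the
# ``E_q`` and ``E_total`` sizes defined earlier), chunking the output of a packed
# projection recovers the same result as applying the corresponding weight slices
# separately:

packed = nn.Linear(E_q, E_total * 3, bias=False, device="cuda")
w_q, w_k, w_v = torch.chunk(packed.weight, 3, dim=0)
x = torch.randn(2, 4, E_q, device="cuda")
q_chunk, k_chunk, v_chunk = torch.chunk(packed(x), 3, dim=-1)
print("Max difference between packed and separate q projection:",
      (q_chunk - x @ w_q.T).abs().max().item())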
+# +# Let us compare the performance of the packed projection against the usual method: + +class InputProjection(nn.Module): + def __init__(self, E_q, E_total, bias=False, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) + self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) + self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) + + def forward(self, x): + return self.q_proj(x), self.k_proj(x), self.v_proj(x) + +class PackedInputProjection(nn.Module): + def __init__(self, E_q, E_total, bias=False, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) + + def forward(self, query): + return torch.chunk(self.packed_proj(query), 3, dim=-1) + +B, D, dtype = 256, 8192, torch.bfloat16 + +torch.set_float32_matmul_precision('high') +in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16)) +packed_in_proj = torch.compile(PackedInputProjection(D, D, device='cuda', dtype=torch.bfloat16)) + +q, _, _, sequence_lengths = gen_batch(B, D, D, D, device='cuda', dtype=torch.bfloat16) + +# warmup +in_proj(q) +packed_in_proj(q) + +# benchmark +(q_out, k_out, v_out), time, _ = benchmark(in_proj, q) +(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q) +# On my A100 prints 1.05x speedup +print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x") + +################################################## +# SwiGLU feed forward network of Transformer Layer +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward +# network of the transformer layer (e.g. Llama). 
A feed-forward network with SwiGLU activation is defined as:

class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

########################################################################
# An alternative way of implementing this that uses packed projection is

class PackedSwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

################################################################################
# We can compare the performance of the two implementations as follows.
# Depending on your hardware, you might see different results; on an A100 we see
# a 1.12x speedup for D=128.
D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16))
packed_swigluffn = torch.compile(PackedSwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16))

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")

################################################################################
# Extended examples
# -----------------
#
# We intend to update this tutorial to demonstrate more examples of how to use
# the various performant building blocks, such as KV-Caching, Grouped Query Attention,
# and so on. In addition, there are several good examples of using these building
# blocks to implement various transformer architectures. Some examples include:
#
# * `gpt-fast `_
# * `segment-anything-fast `_
# * `lucidrains implementation of NaViT with nested tensors `_
# * `torchtune's implementation of VisionTransformer `_

################################################################################
# Conclusion
# ----------
#
# In this tutorial, we have introduced the low-level building blocks PyTorch
# provides for writing transformer layers and demonstrated examples of how to compose
# them.
We hope it has shown how easily
# flexible and performant transformer layers can be implemented by users of PyTorch.