From 5cae1fbe836c3dcf703a7355945225efcf40b813 Mon Sep 17 00:00:00 2001
From: Driss Guessous
Date: Tue, 20 Dec 2022 01:59:55 +0000
Subject: [PATCH] update nestedtensor prototype

---
 prototype_source/nestedtensor.py | 383 ++++++++++++++++---------------
 1 file changed, 193 insertions(+), 190 deletions(-)

diff --git a/prototype_source/nestedtensor.py b/prototype_source/nestedtensor.py
index dd7c8cd2c76..0d2898cc4ac 100644
--- a/prototype_source/nestedtensor.py
+++ b/prototype_source/nestedtensor.py
@@ -1,24 +1,24 @@
 """
-Nested Tensors
+
+NestedTensors
 ===============================================================
 
-Nested tensor is very similar to regular tensor, except for the shape:
+NestedTensors are similar to regular tensors, except for their shape:
 
 * for a regular tensor, each dimension has a size
 
-* for a nested tensor, not all dimensions have regular sizes; some of them are jagged
+* for a nestedtensor, not all dimensions have regular sizes; some of them are jagged
 
-Nested tensors are a natural solution for representing sequential data within various domains:
+NestedTensors are a natural solution for representing sequential data within various domains:
 
-* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor
+* in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor
 
-* in CV, images can have variable shapes, so a batch of images forms a nested tensor
+* in CV, images can have variable shapes, so a batch of images forms a nestedtensor
 
-In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness
+In this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness
 for operating on sequential data of varying lengths with a real-world example.
 
-The nested tensor operations used here have not been released yet.
-You will have to install the latest nightly to run this tutorial.
+NestedTensors are currently a prototype feature and are subject to change.
 """
 
 import torch
@@ -27,31 +27,42 @@
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-# Nested Tensor Initialization
+# NestedTensor Initialization
 # ----------------
 #
 
 ######################################################################
-# From the Python frontend, a nested tensor can be created from a list of tensors.
-nt = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], device=device)
-print(nt)
+# From the Python frontend, a nestedtensor can be created from a list of tensors.
+# We denote nt[i] as the ith tensor component of a nestedtensor.
+nt = torch.nested.nested_tensor([torch.arange(12).reshape(
+    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
+print(f"{nt=}")
 
 ######################################################################
 # By padding every underlying tensor to the same shape,
-# a nested tensor can be converted to a regular tensor.
-pt = torch.nested.to_padded_tensor(nt, padding=0.0)
-print(pt)
+# a nestedtensor can be converted to a regular tensor.
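+# Each tensor component is padded up to the shape of the largest component
+# (here (3, 6)), and the extra entries are filled with the given ``padding`` value.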
+padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
+print(f"{padded_out_tensor=}")
+
+######################################################################
+# All tensors possess an attribute for determining if they are nested:
+print(f"nt is nested: {nt.is_nested}")
+print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")
 
 ######################################################################
-# For practical reasons, conceptually we implement nested tensor
-# as a batch of tensors with different shapes,
+# It is common to construct nestedtensors from batches of irregularly shaped tensors,
 # i.e. dimension 0 is assumed to be the batch dimension.
-# Indexing dimension 0 gives back the underlying tensor.
-print("0th underlying tensor:", nt[0], sep='\n')
-print("last column of 1st underlying tensor:", nt[1, :, -1], sep='\n')
+# Indexing into dimension 0 gives back the corresponding underlying tensor component.
+print("First underlying tensor component:", nt[0], sep='\n')
+print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')
+
+# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
+print(f"First underlying tensor component is nested: {nt[0].is_nested}")
 
 ######################################################################
-# Slicing in dimension 0 has not been supported yet.
+# An important note is that slicing in dimension 0 has not been supported yet,
+# which means it is not currently possible to construct a view that combines the underlying
+# tensor components.
 
 ######################################################################
 # Nested Tensor Operations
@@ -59,10 +70,10 @@
 #
 
 ######################################################################
-# As each operation must be explicitly implemented for nested tensors,
-# operation coverage for nested tensors is currently narrower than that of regular tensors.
+# As each operation must be explicitly implemented for nestedtensors,
+# operation coverage for nestedtensors is currently narrower than that of regular tensors.
 # For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, bmm are covered.
-# However, coverage is being expanded rapidly.
+# However, coverage is being expanded.
 # If you need certain operations, please file an `issue `__
 # to help us prioritize coverage.
 #
@@ -75,11 +86,11 @@
 # a single dimension may be -1, in which case it is inferred
 # from the remaining dimensions and the number of elements.
 #
-# The semantics for nested tensors are similar, except that -1 no longer infers.
+# The semantics for nestedtensors are similar, except that -1 no longer infers.
 # Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
 # -1 is the only legal size to specify for a jagged dimension.
-nt1 = nt.reshape(2, -1, 2, 3)
-print(nt1)
+nt_reshaped = nt.reshape(2, -1, 2, 3)
+print(f"{nt_reshaped=}")
 ######################################################################
 # **transpose**
 #
@@ -87,28 +98,28 @@
 # The transpose op is for swapping two dimensions of a tensor.
 # Its full semantics can be found
 # `here `__.
-# Note that nested tensor dimension 0 is special;
+# Note that for nestedtensors, dimension 0 is special;
 # it is assumed to be the batch dimension,
-# so transposes involving nested tensor dimension 0 are forbidden.
-nt2 = nt1.transpose(1, 2)
-print(nt2)
+# so transposes involving nestedtensor dimension 0 are not supported.
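+# As a quick illustration of that restriction (a hedged sketch: the exact exception
+# type and message are not documented here, so we catch broadly), a transpose that
+# involves dimension 0 is expected to fail:
+try:
+    nt_reshaped.transpose(0, 1)
+except Exception as e:
+    print(f"Transposing dimension 0 failed as expected: {e}")
+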
+nt_transposed = nt_reshaped.transpose(1, 2)
+print(f"{nt_transposed=}")
 
 ######################################################################
 # **others**
 #
 # Other operations have the same semantics as for regular tensors.
-# Applying the operation on a nested tensor is equivalent to
+# Applying the operation on a nestedtensor is equivalent to
 # applying the operation to the underlying tensor components,
-# with the result being a nested tensor as well.
-nt_mm = torch.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
-nt3 = torch.matmul(nt2, nt_mm)
-print("matmul:", nt3, sep='\n')
+# with the result being a nestedtensor as well.
+nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
+nt3 = torch.matmul(nt_transposed, nt_mm)
+print(f"Result of Matmul:\n {nt3}")
 
 nt4 = F.dropout(nt3, 0.1)
-print("dropout:", nt4, sep='\n')
+print(f"Result of Dropout:\n {nt4}")
 
 nt5 = F.softmax(nt4, -1)
-print("softmax:", nt5, sep='\n')
+print(f"Result of Softmax:\n {nt5}")
 
 ######################################################################
 # Why Nested Tensor
@@ -116,24 +127,30 @@
 #
 
 ######################################################################
-# In the age before nested tensor, one has to manually pad each data tensor
-# to the same shape to form a batch as a regular tensor.
-# For example, we have 2 sentences and a vocabulary, then pad with 0.
+# When data is sequential, it is often the case that each sample has a different length.
+# For example, in a batch of sentences, each sentence has a different number of words.
+# A common technique for handling varying sequences is to manually pad each data tensor
+# to the same shape in order to form a batch.
+# For example, here we have 2 sentences with different lengths and a vocabulary.
+# In order to represent this as a single tensor we pad with 0 to the max length in the batch.
 sentences = [["goodbye", "padding"],
              ["embrace", "nested", "tensor"]]
-vocabulary = {"goodbye" : 1.0, "padding" : 2.0,
-              "embrace" : 3.0, "nested" : 4.0, "tensor" : 5.0}
+vocabulary = {"goodbye": 1.0, "padding": 2.0,
+              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
 padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                  [3.0, 4.0, 5.0]])
-nested_sentences = torch.nested_tensor([torch.tensor([1.0, 2.0]),
-                                        torch.tensor([3.0, 4.0, 5.0])])
-print(padded_sentences)
-print(nested_sentences)
+nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
+                                               torch.tensor([3.0, 4.0, 5.0])])
+print(f"{padded_sentences=}")
+print(f"{nested_sentences=}")
 
 ######################################################################
-# Clearly, padding introduces inefficiency.
-# Further, padding with zeros does not correctly treat entries as padding for every operation,
-# e.g. in softmax one has to pad with -inf rather than 0 to ignore specific entries.
+# This technique of padding a batch of data to its max length is not optimal.
+# The padded data is not needed for computation and wastes memory by allocating
+# larger tensors than necessary.
+# Further, not all operations have the same semantics when applied to padded data.
+# For matrix multiplications, in order to ignore the padded entries one needs to pad
+# with 0, while for softmax one has to pad with -inf to ignore specific entries.
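+# The demonstration below pads with -inf: after the softmax the padded entry
+# receives zero weight, which would not be the case if it were padded with 0.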
 padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                              [3.0, 4.0, 5.0]])
 print(F.softmax(padded_sentences_for_softmax, -1))
 
@@ -142,40 +159,42 @@
 ######################################################################
 # Let us take a look at a practical example: the multi-head attention component
 # utilized in `Transformers `__.
-# The nested tensor version is straightforward.
+# The nestedtensor version is straightforward.
 import math
-"""
-Args:
-    query: query of shape (N, L_t, E_q)
-    key: key of shape (N, L_s, E_k)
-    value: value of shape (N, L_s, E_v)
-    nheads: number of heads in multi-head attention
-    W_q: Weight for query input projection of shape (E_total, E_q)
-    W_k: Weight for key input projection of shape (E_total, E_k)
-    W_v: Weight for value input projection of shape (E_total, E_v)
-    W_out: Weight for output projection of shape (E_out, E_total)
-    b_q (optional): Bias for query input projection of shape E_total. Default: None
-    b_k (optional): Bias for key input projection of shape E_total. Default: None
-    b_v (optional): Bias for value input projection of shape E_total. Default: None
-    b_out (optional): Bias for output projection of shape E_out. Default: None
-    dropout_p: dropout probability. Default: 0.0
-    where:
-        N is the batch size
-        L_t is the target sequence length (jagged)
-        L_s is the source sequence length (jagged)
-        E_q is the embedding size for query
-        E_k is the embedding size for key
-        E_v is the embedding size for value
-        E_total is the embedding size for all heads combined
-        E_out is the output embedding size
-Returns:
-    attn_output: Output of shape (N, L_t, E_out)
-"""
-def mha_nested(query, key, value, nheads,
-W_q, W_k, W_v, W_out,
-b_q=None, b_k=None, b_v=None, b_out=None,
-dropout_p=0.0):
+def mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
+               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
+               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
+               dropout_p: float = 0.0) -> torch.Tensor:
+    """Compute multi-head attention with nested tensors.
+    Args:
+        query (torch.Tensor): query of shape (N, L_t, E_q)
+        key (torch.Tensor): key of shape (N, L_s, E_k)
+        value (torch.Tensor): value of shape (N, L_s, E_v)
+        nheads (int): number of heads in multi-head attention
+        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
+        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
+        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
+        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
+        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
+        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
+        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
+        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
+        dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+
+    Where:
+        N is the batch size
+        L_t is the target sequence length (jagged)
+        L_s is the source sequence length (jagged)
+        E_q is the embedding size for query
+        E_k is the embedding size for key
+        E_v is the embedding size for value
+        E_total is the embedding size for all heads combined
+        E_out is the output embedding size
+    Returns:
+        torch.Tensor: Output of shape (N, L_t, E_out)
+    """
+    N = query.size(0)
     E_total = W_q.size(0)
     assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
     E_head = E_total // nheads
 
     # apply input projection
     # (N, L_t, E_q) -> (N, L_t, E_total)
     query = F.linear(query, W_q, b_q)
     # (N, L_s, E_k) -> (N, L_s, E_total)
     key = F.linear(key, W_k, b_k)
     # (N, L_s, E_v) -> (N, L_s, E_total)
     value = F.linear(value, W_v, b_v)
 
@@ -191,11 +210,11 @@ def mha_nested(query, key, value, nheads,
 
     # reshape query, key, value to separate by head
     # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
-    query = query.reshape(-1, -1, nheads, E_head).transpose(1, 2)
+    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2)
     # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
-    key = key.reshape(-1, -1, nheads, E_head).transpose(1, 2)
+    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2)
     # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
-    value = value.reshape(-1, -1, nheads, E_head).transpose(1, 2)
+    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2)
 
     # query matmul key^T
     # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
@@ -229,45 +248,48 @@ def mha_nested(query, key, value, nheads,
 ######################################################################
 # The 0-padded tensor version additionally requires masks
 # for more complicated treatments at padded entries.
-"""
-Args:
-    query: query of shape (N, L_t, E_q)
-    key: key of shape (N, L_s, E_k)
-    value: value of shape (N, L_s, E_v)
-    nheads: number of heads in multi-head attention
-    attn_mask_q: boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
-    attn_mask_kv: boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
-    W_q: Weight for query input projection of shape (E_total, E_q)
-    W_k: Weight for key input projection of shape (E_total, E_k)
-    W_v: Weight for value input projection of shape (E_total, E_v)
-    W_out: Weight for output projection of shape (E_out, E_total)
-    b_q (optional): Bias for query input projection of shape E_total. Default: None
-    b_k (optional): Bias for key input projection of shape E_total. Default: None
-    b_v (optional): Bias for value input projection of shape E_total. Default: None
-    b_out (optional): Bias for output projection of shape E_out. Default: None
-    dropout_p: dropout probability. Default: 0.0
-    where:
-        N is the batch size
-        L_t is the target sequence length (padded)
-        L_s is the source sequence length (padded)
-        E_q is the embedding size for query
-        E_k is the embedding size for key
-        E_v is the embedding size for value
-        E_total is the embedding size for all heads combined
-        E_out is the output embedding size
-Returns:
-    attn_output: Output of shape (N, L_t, E_out)
-"""
-def mha_padded(query, key, value, nheads,
-attn_mask_q, attn_mask_kv,
-W_q, W_k, W_v, W_out,
-b_q=None, b_k=None, b_v=None, b_out=None,
-dropout_p=0.0):
+def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
+               attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,
+               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
+               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
+               dropout_p: float = 0.0) -> torch.Tensor:
+    """Compute multi-head attention for padded dense tensors.
+
+    Args:
+        query (torch.Tensor): query of shape (N, L_t, E_q)
+        key (torch.Tensor): key of shape (N, L_s, E_k)
+        value (torch.Tensor): value of shape (N, L_s, E_v)
+        nheads (int): number of heads in multi-head attention
+        attn_mask_q (torch.Tensor): boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
+        attn_mask_kv (torch.Tensor): boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
+        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
+        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
+        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
+        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
+        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
+        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
+        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
+        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
+        dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+
+    Where:
+        N is the batch size
+        L_t is the target sequence length (padded)
+        L_s is the source sequence length (padded)
+        E_q is the embedding size for query
+        E_k is the embedding size for key
+        E_v is the embedding size for value
+        E_total is the embedding size for all heads combined
+        E_out is the output embedding size
+    Returns:
+        torch.Tensor: Output of shape (N, L_t, E_out)
+    """
     N = query.size(0)
     L_t = query.size(1)
     L_s = key.size(1)
     E_total = W_q.size(0)
     assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
+    assert L_t == L_s, "This implementation assumes equal query and key sequence lengths"
     E_head = E_total // nheads
 
     # apply input projection
@@ -278,19 +300,6 @@ def mha_padded(query, key, value, nheads,
     # (N, L_s, E_v) -> (N, L_s, E_total)
     value = F.linear(value, W_v, b_v)
 
-    # padding-specific step: remove bias from padded entries
-    # in the specific multihead-attention formula it is not necessary to remove these bias
-    # because the -inf padding later on in softmax step can take care of it
-    # but to be general here we demonstrate the bias removal
-    for i in range(N):
-        for j in range(L_t):
-            if attn_mask_q[i, j]:
-                query[i, j, :] = 0.0
-        for j in range(L_s):
-            if attn_mask_kv[i, j]:
-                key[i, j, :] = 0.0
-                value[i, j, :] = 0.0
-
     # reshape query, key, value to separate by head
     # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
     query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
@@ -302,21 +311,19 @@ def mha_padded(query, key, value, nheads,
     # query bmm key^T
     # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
     keyT = key.transpose(-1, -2)
-    # padding-specific step: add -inf mask for padding in softmax
-    attn_mask = query.new_zeros((N, nheads, L_t, L_s))
-    for i in range(N):
-        for j in range(L_t):
-            for k in range(L_s):
-                if attn_mask_q[i, j] or attn_mask_kv[i, k]:
-                    attn_mask[i, :, j, k] = float("-inf")
-    attn_mask = attn_mask.reshape((N * nheads, L_t, L_s))
-    attn_weights = torch.baddbmm(attn_mask, query, keyT)
-    # if no padding, it could have been as simple as
-    # attn_weights = torch.bmm(query, keyT)
+    attn_weights = torch.bmm(query, keyT)
 
     # scale down
     attn_weights = attn_weights * (1.0 / math.sqrt(E_head))
 
+    # Reshape the boolean padding mask so that it can be applied to the attention weights
+    key_padding_mask = attn_mask_q.view(N, 1, 1, L_t).expand(-1, nheads, -1, -1).reshape(N*nheads, 1, L_t).to(device=device)
+    attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)
+    attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))
+
+    # Adding -inf where the mask is True zeroes out those attention weights after the softmax
+    attn_weights.add_(attn_mask)
+
     # softmax
     attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)
 
@@ -337,10 +344,7 @@ def mha_padded(query, key, value, nheads,
     attn_output = F.linear(attn_output, W_out, b_out)
 
     # padding-specific step: remove output projection bias from padded entries
-    for i in range(N):
-        for j in range(L_t):
-            if attn_mask_q[i, j]:
-                attn_output[i, j, :] = 0.0
+    attn_output[attn_mask_q, :] = 0.0
 
     return attn_output
 
@@ -387,17 +391,17 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 
 # create nested input
 queries = []
-keys    = []
-values  = []
+keys = []
+values = []
 for i in range(N):
     l = sentence_lengths[i]
     s = l
     queries.append(torch.randn((l, E_q), device=device))
     keys .append(torch.randn((s, E_k), device=device))
     values .append(torch.randn((s, E_v), device=device))
-query = torch.nested_tensor(queries)
-key = torch.nested_tensor(keys )
-value = torch.nested_tensor(values )
+query = torch.nested.nested_tensor(queries)
+key = torch.nested.nested_tensor(keys)
+value = torch.nested.nested_tensor(values)
 
 # pad input
 padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
@@ -407,13 +411,11 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 # create attention masks
 attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
 attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)
-for i in range(N):
-    for j in range(L_t):
-        if padded_query[i, j, :].abs().max().item() == 0.0:
-            attn_mask_q[i, j] = True
-    for j in range(L_s):
-        if padded_key[i, j, :].abs().max().item() == 0.0:
-            attn_mask_kv[i, j] = True
+
+# We need to mask out the padding entries in the attention weights.
+for i, entry_length in enumerate(sentence_lengths):
+    attn_mask_q[i, entry_length:] = True
+    attn_mask_kv[i, entry_length:] = True
 
 ######################################################################
 # check correctness and performance
@@ -437,15 +439,16 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 t2 = timeit.default_timer()
 
 print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
-print("nested tensor multi-head attention takes", t1 - t0, "seconds")
+print("nestedtensor multi-head attention takes", t1 - t0, "seconds")
 print("padded tensor multi-head attention takes", t2 - t1, "seconds")
 
 ######################################################################
-# The nested tensor version avoids wasted computation on padding,
-# so in sequential CPU execution it is faster than padded tensor version as expected.
-# Optimization for multi-threaded environment is underway.
+# Although the nestedtensor version avoids wasted computation on padding, it is not faster
+# than the equivalent padded tensor version. This is because the nestedtensor version
+# currently implements a few of its kernels, such as softmax, in a suboptimal way.
 #
-# For now, performant kernels are provided for specific use cases, e.g.
+# There are plans to implement performance-critical operations using the new PyTorch 2.0 stack.
+# For now, some performant kernels are provided for specific use cases, e.g.
 # self-attention evaluation by multi-head attention formula.
 
 # embeddings are assumed to be the same
@@ -465,28 +468,28 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias
 
 ######################################################################
-# check correctness and performance
-
-t0 = timeit.default_timer()
-out_lib, out_lib_weights = mha_lib(query, query, query)
-
-t1 = timeit.default_timer()
-out_nested = mha_nested(
-    query, query, query, nheads,
-    W_q, W_k, W_v, W_out,
-    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
-    dropout_p=dropout_p)
-
-t2 = timeit.default_timer()
-padded_out = mha_padded(
-    padded_query, padded_query, padded_query, nheads,
-    attn_mask_q, attn_mask_q,
-    W_q, W_k, W_v, W_out,
-    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
-    dropout_p=dropout_p)
-t3 = timeit.default_timer()
-
-print("nested general and library calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0) - torch.nested.to_padded_tensor(out_lib, 0.0)).abs().max().item())
-print("nested library multi-head attention takes", t1 - t0, "seconds")
-print("nested general multi-head attention takes", t2 - t1, "seconds")
-print("padded tensor multi-head attention takes", t3 - t2, "seconds")
+# If we set need_weights to False, this will enable the fast path in the library.
+# Under the hood this will call _scaled_dot_product_attention. If your tensors
+# are on CUDA, then a fused, efficient attention kernel will be used. For
+# more detailed performance characteristics, look at the benchmark in
+# pytorch/benchmarks/transformer/sdp.py
+
+with torch.inference_mode():
+    t0 = timeit.default_timer()
+    out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)
+
+    t1 = timeit.default_timer()
+    padded_out = mha_padded(
+        padded_query, padded_query, padded_query, nheads,
+        attn_mask_q, attn_mask_q,
+        W_q, W_k, W_v, W_out,
+        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
+        dropout_p=dropout_p)
+    t2 = timeit.default_timer()
+
+nested_time = t1 - t0
+padded_time = t2 - t1
+print("Nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())
+print("Nested library multi-head attention takes", nested_time, "seconds")
+print("Padded tensor multi-head attention takes", padded_time, "seconds")
+print(f"Nested Speedup: {padded_time / nested_time:.3f}")
\ No newline at end of file