# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
# ``value`` to be dense ``torch.Tensor``s. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
- # that arise due to different sequence lengths within a batch.
+ # that arise due to different sequence lengths within a batch. Since there is
+ # no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
+ # the outputs appropriately to account for query sequence lengths. Nested tensors
+ # cleanly remove the need for this sort of error-prone padding mask.
#
# * Memory
#   Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
#
# * Performance
#   Since unnecessary computation on padding is skipped, performance improves.
- # We'll demonstrate this by building off the ``MultiheadAttention`` layer in the
+ #
+ # We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+ # and comparing it to the ``nn.MultiheadAttention`` layer.

import torch
import torch.nn as nn
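A minimal sketch of the padding problem described above, assuming a recent PyTorch where the ``torch.nested`` jagged layout is available:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths, each [seq_len, embed_dim].
seqs = [torch.randn(L, 8) for L in (2, 5, 3)]

# Dense approach: pad to the max length and carry a [B, S] boolean mask
# (True marks padding) so attention can ignore the padded key positions.
padded = pad_sequence(seqs, batch_first=True)  # [3, 5, 8]
key_padding_mask = torch.tensor(
    [[False] * L + [True] * (5 - L) for L in (2, 5, 3)]
)  # [3, 5]

# Nested approach: keep the ragged batch as-is; no padding mask is needed.
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
print(padded.shape, key_padding_mask.shape, nt.is_nested)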
@@ -142,6 +147,7 @@ class MultiHeadAttention(nn.Module):
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
+        bias (bool, optional): Whether to add bias to input projection. Default: True
    """
    def __init__(
        self,
@@ -151,7 +157,7 @@ def __init__(
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
-        bias=False,
+        bias=True,
        device=None,
        dtype=None,
    ):
@@ -163,15 +169,21 @@ def __init__(
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
-            self.query_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
-            self.key_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
-            self.value_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
+            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
+            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
+            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
-
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, is_causal=False) -> torch.Tensor:
+        self.bias = bias
+
+    def forward(self,
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                attn_mask=None,
+                is_causal=False) -> torch.Tensor:
        """
        Forward pass; runs the following process:
        1. Apply input projection
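A hedged usage sketch of the layer as changed above; it assumes the complete ``MultiHeadAttention`` class from this file is in scope, and the keyword names are taken from the hunks:

import torch

# bias now defaults to True, matching nn.MultiheadAttention's behavior.
mha = MultiHeadAttention(E_q=64, E_k=64, E_v=64, E_total=64, nheads=8)

x = torch.randn(2, 10, 64)          # [batch, seq, embed] for self-attention
out = mha(x, x, x, is_causal=True)  # same tensor for q/k/v -> packed projection path
print(out.shape)                    # expected: torch.Size([2, 10, 64])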
@@ -196,16 +208,16 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0)
-                if bias:
+                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0)
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias)

        else:
-            query = self.query_proj(query)
-            key = self.key_proj(key)
-            value = self.value_proj(value)
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
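The weight-chunking branch above is equivalent to applying the packed projection in one shot. A small self-contained sketch of that equivalence (the sizes here are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

E = 16
x = torch.randn(2, 4, E)  # [batch, seq, embed]

packed = nn.Linear(E, 3 * E, bias=True)

# Self-attention case: project once, then split the last dim into q/k/v.
q, k, v = torch.chunk(packed(x), 3, dim=-1)

# Equivalent view: split the packed weight/bias and apply three F.linear calls,
# which is what the layer does when query/key/value are different tensors.
q_w, k_w, v_w = torch.chunk(packed.weight, 3, dim=0)
q_b, k_b, v_b = torch.chunk(packed.bias, 3, dim=0)
assert torch.allclose(q, F.linear(x, q_w, q_b), atol=1e-6)
assert torch.allclose(v, F.linear(x, v_w, v_b), atol=1e-6)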
@@ -219,7 +231,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attn_mask, dropout=self.dropout, is_causal=is_causal)
+            query, key, value, dropout_p=self.dropout, is_causal=is_causal)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)
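A minimal sketch of the split-heads / SDPA / merge-heads sequence around the corrected ``dropout_p`` keyword; the shapes are arbitrary and ``unflatten`` stands in for whatever reshape the full layer uses:

import torch
import torch.nn.functional as F

N, L, nheads, E_head = 2, 6, 4, 8
E_total = nheads * E_head
x = torch.randn(N, L, E_total)

# Step 2 (as in the layer above): split the embedding dim into heads and move
# the head dim next to the batch dim: [N, L, E_total] -> [N, nheads, L, E_head].
q = k = v = x.unflatten(-1, (nheads, E_head)).transpose(1, 2)

# Step 3: SDPA takes dropout as ``dropout_p``, the keyword the diff fixes above.
attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# Merge heads back: [N, nheads, L, E_head] -> [N, L, E_total].
out = attn.transpose(1, 2).flatten(-2)
print(out.shape)  # torch.Size([2, 6, 32])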
@@ -395,11 +407,10 @@ def benchmark(func, *args, **kwargs):
# followed by a feed-forward network (FFN) with skip connections. Implementing
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
- #

- # We will demonstrate examples of implementing the rest of the nn layers but will
- # omit that from this tutorial for brevity. The full code is available
- # `here <https://github.com/mikaylagawarecki/temp>`_.
+ # We demonstrate examples of implementing the rest of the nn layers
+ # `here <https://github.com/mikaylagawarecki/temp>`_ but omit them from this
+ # tutorial for brevity.
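A rough sketch of such a causal encoder layer, assuming the ``MultiHeadAttention`` class above is in scope; the class name, pre-norm placement, and FFN sizes below are illustrative choices, not taken from the PR:

import torch
import torch.nn as nn


class CausalEncoderLayer(nn.Module):
    """Illustrative only: causal self-attention + FFN, each with a skip connection."""

    def __init__(self, d_model: int, nheads: int, dim_ff: int = 2048, dropout: float = 0.0):
        super().__init__()
        # Assumes the MultiHeadAttention layer defined earlier in this file.
        self.self_attn = MultiHeadAttention(
            E_q=d_model, E_k=d_model, E_v=d_model, E_total=d_model,
            nheads=nheads, dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual blocks; nn.TransformerEncoderLayer supports the same
        # pattern via norm_first=True and is_causal=True in its forward.
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, is_causal=True)
        x = x + self.ff(self.norm2(x))
        return x


layer = CausalEncoderLayer(d_model=64, nheads=8)
print(layer(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])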

###############################################################################
# Going one step further