
Commit 71c1bac: "some wording fixes"
Parent: c881238

1 file changed: intermediate_source/transformer_building_blocks.py (+28 additions, -17 deletions)
@@ -112,7 +112,10 @@
 # Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
 # ``value`` to be dense ``torch.Tensor``s. It also provides a
 # ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
-# that arise due to different sequence lengths within a batch.
+# that arise due to different sequence lengths within a batch. Since there is
+# no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
+# the outputs appropriately to account for query sequence lengths. Nested tensors
+# cleanly remove the need for this sort of error-prone padding mask.
 #
 # * Memory
 #   Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
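A minimal sketch (not part of this commit) of the padded-versus-nested contrast described in the hunk above; the sequence lengths and embedding size are made up, and the jagged nested-tensor layout assumes a recent PyTorch build:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths, embedding dim 8 (illustrative values).
seqs = [torch.randn(n, 8) for n in (3, 5, 2)]

# Dense path: pad to [B, S, D] and build the [B, S] key_padding_mask that
# nn.MultiheadAttention expects (True marks padding positions to ignore).
padded = pad_sequence(seqs, batch_first=True)                     # [3, 5, 8]
lengths = torch.tensor([s.shape[0] for s in seqs])
key_padding_mask = torch.arange(padded.shape[1])[None, :] >= lengths[:, None]

# Nested path: the jagged layout carries per-sequence lengths itself, so no
# padding mask is needed and no query positions have to be sliced away later.
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
print(padded.shape, key_padding_mask.shape, nt.is_nested)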
@@ -123,8 +126,10 @@
 #
 # * Performance
 #   Since unnecessary computation on padding is skipped, performance improves.
-# We'll demonstrate this by building off the ``MultiheadAttention`` layer in the
+#
+# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
 # `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+# and comparing it to the ``nn.MultiheadAttention`` layer.
 
 import torch
 import torch.nn as nn
@@ -142,6 +147,7 @@ class MultiHeadAttention(nn.Module):
             has dim E_total // nheads
         nheads (int): Number of heads
         dropout (float, optional): Dropout probability. Default: 0.0
+        bias (bool, optional): Whether to add bias to input projection. Default: True
     """
     def __init__(
         self,
@@ -151,7 +157,7 @@ def __init__(
         E_total: int,
         nheads: int,
         dropout: float = 0.0,
-        bias=False,
+        bias=True,
         device=None,
         dtype=None,
     ):
@@ -163,15 +169,21 @@ def __init__(
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
-            self.query_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
-            self.key_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
-            self.value_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
+            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
+            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
+            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
-
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, is_causal=False) -> torch.Tensor:
+        self.bias = bias
+
+    def forward(self,
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                attn_mask=None,
+                is_causal=False) -> torch.Tensor:
        """
        Forward pass; runs the following process:
        1. Apply input projection
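For orientation only, a hedged sketch of how the layer might be called once this diff is applied. It assumes the ``MultiHeadAttention`` class defined in this file, illustrative hyperparameters (embedding dim 64, 8 heads), and a PyTorch build where SDPA accepts jagged nested tensors:

import torch

# Illustrative hyperparameters, not taken from the commit.
mha = MultiHeadAttention(E_q=64, E_k=64, E_v=64, E_total=64, nheads=8)

# Causal self-attention over a nested (jagged) batch of variable-length sequences.
seqs = [torch.randn(n, 64) for n in (3, 5, 2)]
x = torch.nested.nested_tensor(seqs, layout=torch.jagged)
out = mha(x, x, x, is_causal=True)   # forward(query, key, value, attn_mask=None, is_causal=False)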
@@ -196,16 +208,16 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0)
-                if bias:
+                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0)
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias)

        else:
-            query = self.query_proj(query)
-            key = self.key_proj(key)
-            value = self.value_proj(value)
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
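The ``if self.bias:`` fix in the hunk above matters because the q/k/v bias chunks come from the packed projection's parameters, not from a local variable. A standalone sketch of that chunking pattern, with an illustrative embedding dim:

import torch
import torch.nn.functional as F

E = 16                                  # illustrative embedding dim
packed = torch.nn.Linear(E, 3 * E)      # packed q/k/v projection with bias

# Split the packed weight/bias into per-projection pieces along dim 0.
q_w, k_w, v_w = torch.chunk(packed.weight, 3, dim=0)   # each [E, E]
q_b, k_b, v_b = torch.chunk(packed.bias, 3, dim=0)     # each [E]

x = torch.randn(2, 4, E)                # [batch, seq, E]
q = F.linear(x, q_w, q_b)               # matches the first third of packed(x)
torch.testing.assert_close(q, packed(x)[..., :E])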
@@ -219,7 +231,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attn_mask, dropout=self.dropout, is_causal=is_causal)
+            query, key, value, dropout_p=self.dropout, is_causal=is_causal)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)
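The hunk above replaces the invalid ``dropout=`` keyword with SDPA's ``dropout_p`` argument and stops forwarding ``attn_mask``. A minimal, self-contained call on dense tensors (shapes chosen only for illustration):

import torch
import torch.nn.functional as F

# (N, nheads, L, E_head) = (2, 8, 5, 16), illustrative sizes.
q = torch.randn(2, 8, 5, 16)
k = torch.randn(2, 8, 5, 16)
v = torch.randn(2, 8, 5, 16)

# dropout_p (not "dropout") is the keyword; is_causal=True applies a causal
# mask internally, so no explicit attn_mask is required here.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)   # torch.Size([2, 8, 5, 16])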

@@ -395,11 +407,10 @@ def benchmark(func, *args, **kwargs):
 # followed by a feed-forward network (FFN) with skip connections. Implementing
 # this is fairly straightforward using the ``MultiheadAttention`` layer above and
 # is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
-#
 
-# We will demonstrate examples of implementing the rest of the nn layers but will
-# omit that from this tutorial for brevity. The full code is available
-# `here <https://github.com/mikaylagawarecki/temp>`_.
+# We demonstrate examples of implementing the rest of the nn layers
+# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
+# tutorial for brevity.
 
 ###############################################################################
 # Going one step further
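As a rough reference for the claim that this matches ``nn.TransformerEncoderLayer`` with ``is_causal=True``, a sketch of the built-in layer run causally. Hyperparameters are illustrative, and an explicit causal mask is passed alongside the ``is_causal`` hint, which is the conservative usage:

import torch
import torch.nn as nn

# Illustrative sizes: d_model=64, 8 heads, batch-first layout.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)

x = torch.randn(2, 5, 64)                                   # [batch, seq, d_model]
causal_mask = nn.Transformer.generate_square_subsequent_mask(5)
out = layer(x, src_mask=causal_mask, is_causal=True)        # causal self-attention
print(out.shape)                                            # torch.Size([2, 5, 64])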
