Commit 2400fed

pyspelling + linkcheck + flex
1 parent 09642b4 commit 2400fed

2 files changed: +72 additions, -57 deletions

en-wordlist.txt

Lines changed: 15 additions & 1 deletion
@@ -127,6 +127,7 @@ Kihyuk
 Kiuk
 Kubernetes
 Kuei
+KV-Caching
 LRSchedulers
 LSTM
 LSTMs
@@ -275,6 +276,7 @@ Xcode
 Xeon
 Yidong
 YouTube
+Zipf
 accelerometer
 accuracies
 activations
@@ -304,6 +306,7 @@ bbAP
 benchmarked
 benchmarking
 bitwise
+bool
 boolean
 breakpoint
 broadcasted
@@ -332,6 +335,7 @@ csv
 cuDNN
 cuda
 customizable
+customizations
 datafile
 dataflow
 dataframe
@@ -376,6 +380,7 @@ fbgemm
 feedforward
 finetune
 finetuning
+FlexAttention
 fp
 frontend
 functionalized
@@ -430,6 +435,7 @@ mAP
 macos
 manualSeed
 matmul
+matmuls
 matplotlib
 memcpy
 memset
@@ -445,6 +451,7 @@ modularized
 mpp
 mucosa
 multihead
+MultiheadAttention
 multimodal
 multimodality
 multinode
@@ -455,7 +462,10 @@ multithreading
 namespace
 natively
 ndarrays
+nheads
 nightlies
+NJT
+NJTs
 num
 numericalize
 numpy
@@ -531,6 +541,7 @@ runtime
 runtime
 runtimes
 scalable
+SDPA
 sharded
 softmax
 sparsified
@@ -590,12 +601,14 @@ tradeoff
 tradeoffs
 triton
 uint
+UX
 umap
 uncomment
 uncommented
 underflowing
 unfused
 unimodal
+unigram
 unnormalized
 unoptimized
 unparametrized
@@ -617,6 +630,7 @@ warmstarted
 warmstarting
 warmup
 webp
+wikitext
 wsi
 wsis
 Meta's
@@ -647,4 +661,4 @@ url
 colab
 sharders
 Criteo
-torchrec
+torchrec

intermediate_source/transformer_building_blocks.py

Lines changed: 57 additions & 56 deletions
@@ -40,7 +40,7 @@
 Please head there instead!
 
 If you are only interested in performant attention score modifications, please
-head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
+head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_ .
 
 If you are wondering about what building blocks the ``torch`` library provides
@@ -63,7 +63,7 @@
 * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
 
 ``scaled_dot_product_attention`` is a primitive for
-$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
+:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
 implementations of the operator or a fallback implementation. It works out of
 the box in eager mode (i.e. the default mode of using PyTorch where operations
 are executed on the fly as they are encountered) and also integrates seamlessly
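
A minimal sketch of calling this primitive directly (the shapes below are illustrative assumptions, not part of this commit):

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: batch=2, heads=4, sequence length=8, head dim=16
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)

    # Dispatches to a fused kernel when one is available, otherwise a fallback implementation
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 8, 16])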
@@ -118,7 +118,7 @@
 # The improvements are threefold:
 #
 # * User Experience
-# Recall that `nn.MultiheadAttention` requires ``query```, ``key`` and
+# Recall that ``nn.MultiheadAttention`` requires ``query```, ``key`` and
 # ``value`` to be dense ``torch.Tensor``s. It also provides a
 # ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
 # that arise due to different sequence lengths within a batch. Since there is
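
A minimal sketch of the nested (jagged) tensor layout hinted at here, which avoids padding and a ``key_padding_mask`` entirely; the lengths and embedding size are illustrative assumptions:

    import torch

    E = 32
    lengths = [3, 5, 2]  # varying sequence lengths within a batch

    # One jagged-layout nested tensor instead of a padded dense tensor + key_padding_mask
    nt = torch.nested.nested_tensor(
        [torch.randn(l, E) for l in lengths], layout=torch.jagged
    )
    print(nt.shape)  # the middle dimension is ragged (jagged)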
@@ -202,10 +202,10 @@ def forward(self,
 4. Apply output projection
 
 Args:
-query (torch.Tensor): query of shape (N, L_q, E_qk)
-key (torch.Tensor): key of shape (N, L_kv, E_qk)
-value (torch.Tensor): value of shape (N, L_kv, E_v)
-attn_mask (torch.Tensor, optional): attention mask of shape (N, L_q, L_kv) to pass to sdpa. Default: None
+query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
+key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
+value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
+attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
 is_causal (bool, optional): Whether to apply causal mask. Default: False
 
 Returns:
@@ -251,11 +251,10 @@ def forward(self,
 
 return attn_output
 
-# TODO: Check whether there is a way to collapse this section by default
-# sphinx.collapse?
+
 ###############################################################################
 # Utilities
-# =========
+# ========================
 # In this section, we include a utility to generate semi-realistic data using
 # Zipf distribution for sentence lengths. This is used to generate the nested
 # query, key and value tensors. We also include a benchmark utility.
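
A minimal sketch of drawing Zipf-distributed sentence lengths; the exponent, clipping, and helper name are illustrative assumptions rather than the tutorial's actual utility:

    import numpy as np
    import torch

    def sample_sentence_lengths(batch_size, max_len, alpha=1.2, seed=0):
        # Smaller alpha gives a heavier tail; clip so lengths stay in [1, max_len]
        rng = np.random.default_rng(seed)
        lengths = np.clip(rng.zipf(alpha, size=batch_size), 1, max_len)
        return torch.tensor(lengths, dtype=torch.int64)

    print(sample_sentence_lengths(8, 512))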
@@ -343,7 +342,7 @@ def benchmark(func, *args, **kwargs):
 torch.manual_seed(6)
 vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')
 
-# nn.MultiheadAttention uses a non conventional init for layers, so do this for exact parity :(
+# nn.MultiheadAttention uses a non conventional initialization for layers, so do this for exact parity :(
 mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
 mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
 mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
@@ -357,8 +356,8 @@ def benchmark(func, *args, **kwargs):
 nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
 padded_nested_result = nested_result.to_padded_tensor(0.0)
 
-# For the vanilla nn.MHA, we need to construct the key_padding_mask
-# Further, nn.MultiheadAttention forces one to materialize the attn_mask even if using is_causal
+# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
+# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
 src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
 attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
 for i, s in enumerate(sentence_lengths):
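
The hunk stops at the loop header; a self-contained sketch of how such a per-sequence causal ``attn_mask`` is typically materialized (the loop body here is an assumption for illustration, not necessarily the tutorial's exact code):

    import torch

    device = "cpu"  # illustrative
    sentence_lengths = [3, 5, 2]
    N, S = len(sentence_lengths), max(sentence_lengths)

    attn_mask = torch.full((N, S, S), float('-inf'), device=device)
    for i, s in enumerate(sentence_lengths):
        # zeros on/below the diagonal allow attention, -inf above enforces causality;
        # positions past length s stay fully masked
        attn_mask[i, :s, :s] = torch.triu(
            torch.full((s, s), float('-inf'), device=device), diagonal=1
        )
    print(attn_mask[0])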
@@ -431,14 +430,14 @@ def benchmark(func, *args, **kwargs):
 # classification of modifications to the transformer architecture, recall that we
 # classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
+# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
 #
 # In this section, we will discuss various functionalities using the
 # aforementioned building blocks. In particular,
 #
 # * Packed Projection
 # * Cross Attention
-# * Fully masked rows no longer cause NaNs
+# * Fully masked rows no longer cause ``NaN``s
 # * [TODO] Modifying attention score: Relative Positional Embedding with NJT
 # * [TODO] KV-Caching with NJT
 # * [TODO] Grouped Query Attention with NJT
@@ -448,13 +447,13 @@ def benchmark(func, *args, **kwargs):
 # -----------------
 #
 # Packed projection is a technique that makes use of the fact that when the input
-# for projection (matmul) are the same (self-attention), we can pack the projection
+# for projection (matrix multiplications) are the same (self-attention), we can pack the projection
 # weights and biases into single tensors. It is especially useful when the individual
-# projections (matmuls) are memory bound rather than compute bound. There are
+# projections are memory bound rather than compute bound. There are
 # two examples that we will demonstrate here:
 #
 # * Input projection for MultiheadAttention
-# * SwiGLU activation in FFN of Transformer Layer
+# * SwiGLU activation in feed-forward network of Transformer Layer
 #
 # Input projection for MultiheadAttention
 # ----------------------------------------
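
A minimal sketch of the packed-projection idea for self-attention, under illustrative dimensions (not the tutorial's actual module):

    import torch
    import torch.nn as nn

    E = 64
    x = torch.randn(2, 10, E)  # self-attention: query == key == value == x

    # Three separate projections: three smaller matrix multiplications
    q_proj, k_proj, v_proj = nn.Linear(E, E), nn.Linear(E, E), nn.Linear(E, E)
    q, k, v = q_proj(x), k_proj(x), v_proj(x)

    # Packed projection: one larger matrix multiplication, then split the result
    packed_proj = nn.Linear(E, 3 * E)
    q2, k2, v2 = packed_proj(x).chunk(3, dim=-1)
    print(q2.shape, k2.shape, v2.shape)  # each torch.Size([2, 10, 64])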
@@ -505,7 +504,7 @@ def forward(self, query):
 # SwiGLU feed forward network of Transformer Layer
 # ------------------------------------------------
 # SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
-# network of the transformer layer (e.g. Llama). A FFN with SwiGLU activation is defined as
+# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
 
 class SwiGLUFFN(nn.Module):
     def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
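
The hunk truncates ``SwiGLUFFN`` at its constructor signature; the core computation of a SwiGLU feed-forward network, as a stripped-down sketch (an illustration, not the tutorial's full class):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinySwiGLUFFN(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            # SwiGLU: gate SiLU(w1(x)) elementwise with w3(x), then project back down
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

    print(TinySwiGLUFFN(64, 128)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])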
@@ -601,45 +600,47 @@ def forward(self, x):
 
 
 ################################################################################
-# [PENDING] KV-Caching with NJT
-# ----------------------------
-# During decoding in inference, the query comprises of the current token. However,
-# the key and value comprises of all the previous keys and values in addition to
-# the current token.
-#
-# When we do batched inference, each batch item will be at a different stage of
-# decoding, so we expect the keys and values to have different sequence lengths.
-# The query is a dense tensor of shape ``[B, 1, E_qk]`` and the keys and values
-# will be of shapes ``[B, *, E_qk]`` and ``[B, *, E_v]`` where ``B`` represents
-# batch size, ``*`` represents varying sequence lengths and ``E_qk`` and ``E_v``
-# are embedding dimensions for query/key and value respectively.
-
-# Directly related to the above point is the idea of KV-Caching. This is a technique
-# that is used in inference to reduce the latency of decoding. The idea is to cache
-# the key and value tensors for the previous tokens and use them for the current
-# token. This is especially useful when the sequence length is long.
-
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/135722
-
-
-################################################################################
-# [PENDING] Relative Positional Embedding with NJT (FlexAttention + NJT)
+# ALiBi with NJT (FlexAttention + NJT)
 # ---------------------------------------------------------------------
-#
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/136792
+# NJT also composes with the ``FlexAttention`` module. This is a generalization
+# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
+# to the attention score. The example below takes the ``alibi_mod`` from
+# attention gym and uses it with nested input tensors.
 
+from torch.nn.attention.flex_attention import flex_attention
+
+def generate_alibi_bias(H: int):
+    """Returns an alibi bias score_mod given the number of heads H
+    Args:
+        H: number of heads
+    Returns:
+        alibi_bias: alibi bias score_mod
+    """
+    def alibi_mod(score, b, h, q_idx, kv_idx):
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (q_idx - kv_idx) * scale
+        return score + bias
+    return alibi_mod
+
+query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+n_heads, D = 8, E_q // 8
+alibi_score_mod = generate_alibi_bias(n_heads)
+query = (
+    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+value = (
+    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
 
 ################################################################################
-# [PENDING] Grouped Query Attention with NJT
-# ------------------------------------------
-#
-# Grouped Query Attention refers to using a number of key/value heads that is
-# less than the number of query heads. Compared to MultiheadAttention, this
-# decreases the size of the kv-cache during inference.
-#
-# We can implement this using nested tensors as follows
+# And more
+# --------
 #
-# FIXME: Pending FlexAttention/testing for NJT with grouped query attention
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks such as KV-Caching, Grouped Query Attention
+# etc.
 
 
 ################################################################################
@@ -649,7 +650,7 @@ def forward(self, x):
 # There are several good examples of using various performant building blocks to
 # implement various transformer architectures. Some examples include
 #
-# * `gpt_fast <https://github.com/pytorch-labs/gpt-fast>`_
-# * `sam_fast <https://github.com/pytorch-labs/sam-fast>`_
-# * `lucidrains implementation of ViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/nested_tensor.py>`_
+# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
+# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
+# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
 # * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_
