Commit c11927a

pyspelling + linkcheck + flex
1 parent a83f66c commit c11927a

2 files changed (+71, -56 lines)

en-wordlist.txt

Lines changed: 14 additions & 0 deletions
@@ -128,6 +128,7 @@ Kihyuk
 Kiuk
 Kubernetes
 Kuei
+KV-Caching
 LRSchedulers
 LSTM
 LSTMs
@@ -276,6 +277,7 @@ Xcode
 Xeon
 Yidong
 YouTube
+Zipf
 accelerometer
 accuracies
 activations
@@ -305,6 +307,7 @@ bbAP
 benchmarked
 benchmarking
 bitwise
+bool
 boolean
 breakpoint
 broadcasted
@@ -333,6 +336,7 @@ csv
 cuDNN
 cuda
 customizable
+customizations
 datafile
 dataflow
 dataframe
@@ -377,6 +381,7 @@ fbgemm
 feedforward
 finetune
 finetuning
+FlexAttention
 fp
 frontend
 functionalized
@@ -431,6 +436,7 @@ mAP
 macos
 manualSeed
 matmul
+matmuls
 matplotlib
 memcpy
 memset
@@ -446,6 +452,7 @@ modularized
 mpp
 mucosa
 multihead
+MultiheadAttention
 multimodal
 multimodality
 multinode
@@ -456,7 +463,10 @@ multithreading
 namespace
 natively
 ndarrays
+nheads
 nightlies
+NJT
+NJTs
 num
 numericalize
 numpy
@@ -532,6 +542,7 @@ runtime
 runtime
 runtimes
 scalable
+SDPA
 sharded
 softmax
 sparsified
@@ -591,12 +602,14 @@ tradeoff
 tradeoffs
 triton
 uint
+UX
 umap
 uncomment
 uncommented
 underflowing
 unfused
 unimodal
+unigram
 unnormalized
 unoptimized
 unparametrized
@@ -618,6 +631,7 @@ warmstarted
 warmstarting
 warmup
 webp
+wikitext
 wsi
 wsis
 Meta's

intermediate_source/transformer_building_blocks.py

Lines changed: 57 additions & 56 deletions
@@ -40,7 +40,7 @@
 Please head there instead!
 
 If you are only interested in performant attention score modifications, please
-head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
+head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_ .
 
 If you are wondering about what building blocks the ``torch`` library provides
@@ -63,7 +63,7 @@
 * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
 
 ``scaled_dot_product_attention`` is a primitive for
-$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
+:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
 implementations of the operator or a fallback implementation. It works out of
 the box in eager mode (i.e. the default mode of using PyTorch where operations
 are executed on the fly as they are encountered) and also integrates seamlessly
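
For orientation while reading this hunk, a minimal standalone call to the primitive looks roughly like the following; the shapes are illustrative and not taken from the tutorial:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- sizes chosen only for illustration
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
# Computes softmax(Q K^T / sqrt(E) + B) V, dispatching to a fused kernel when one is available
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)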
@@ -118,7 +118,7 @@
 # The improvements are threefold:
 #
 # * User Experience
-#   Recall that `nn.MultiheadAttention` requires ``query```, ``key`` and
+#   Recall that ``nn.MultiheadAttention`` requires ``query```, ``key`` and
 #   ``value`` to be dense ``torch.Tensor``s. It also provides a
 #   ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
 #   that arise due to different sequence lengths within a batch. Since there is
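
The jagged (NJT) layout that the tutorial builds toward removes the need for such padding masks. A minimal sketch, assuming a recent PyTorch build with the ``torch.jagged`` layout (not code from this commit):

import torch

# Three sequences of different lengths with embedding dimension 512 (illustrative sizes)
seqs = [torch.randn(5, 512), torch.randn(12, 512), torch.randn(3, 512)]
# layout=torch.jagged builds a nested jagged tensor (NJT) of shape (B, *, 512),
# so no padding and no key_padding_mask are needed downstream
njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)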
@@ -202,10 +202,10 @@ def forward(self,
         4. Apply output projection
 
         Args:
-            query (torch.Tensor): query of shape (N, L_q, E_qk)
-            key (torch.Tensor): key of shape (N, L_kv, E_qk)
-            value (torch.Tensor): value of shape (N, L_kv, E_v)
-            attn_mask (torch.Tensor, optional): attention mask of shape (N, L_q, L_kv) to pass to sdpa. Default: None
+            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
+            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
+            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
+            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
             is_causal (bool, optional): Whether to apply causal mask. Default: False
 
         Returns:
@@ -251,11 +251,10 @@ def forward(self,
 
         return attn_output
 
-# TODO: Check whether there is a way to collapse this section by default
-# sphinx.collapse?
+
 ###############################################################################
 # Utilities
-# =========
+# ========================
 # In this section, we include a utility to generate semi-realistic data using
 # Zipf distribution for sentence lengths. This is used to generate the nested
 # query, key and value tensors. We also include a benchmark utility.
@@ -343,7 +342,7 @@ def benchmark(func, *args, **kwargs):
 torch.manual_seed(6)
 vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')
 
-# nn.MultiheadAttention uses a non conventional init for layers, so do this for exact parity :(
+# nn.MultiheadAttention uses a non conventional initialization for layers, so do this for exact parity :(
 mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
 mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
 mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
@@ -357,8 +356,8 @@ def benchmark(func, *args, **kwargs):
 nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
 padded_nested_result = nested_result.to_padded_tensor(0.0)
 
-# For the vanilla nn.MHA, we need to construct the key_padding_mask
-# Further, nn.MultiheadAttention forces one to materialize the attn_mask even if using is_causal
+# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
+# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
 src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
 attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
 for i, s in enumerate(sentence_lengths):
@@ -431,14 +430,14 @@ def benchmark(func, *args, **kwargs):
 # classification of modifications to the transformer architecture, recall that we
 # classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
+# (e.g. swapping``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
 #
 # In this section, we will discuss various functionalities using the
 # aforementioned building blocks. In particular,
 #
 # * Packed Projection
 # * Cross Attention
-# * Fully masked rows no longer cause NaNs
+# * Fully masked rows no longer cause ``NaN``s
 # * [TODO] Modifying attention score: Relative Positional Embedding with NJT
 # * [TODO] KV-Caching with NJT
 # * [TODO] Grouped Query Attention with NJT
@@ -448,13 +447,13 @@ def benchmark(func, *args, **kwargs):
 # -----------------
 #
 # Packed projection is a technique that makes use of the fact that when the input
-# for projection (matmul) are the same (self-attention), we can pack the projection
+# for projection (matrix multiplications) are the same (self-attention), we can pack the projection
 # weights and biases into single tensors. It is especially useful when the individual
-# projections (matmuls) are memory bound rather than compute bound. There are
+# projections are memory bound rather than compute bound. There are
 # two examples that we will demonstrate here:
 #
 # * Input projection for MultiheadAttention
-# * SwiGLU activation in FFN of Transformer Layer
+# * SwiGLU activation in feed-forward network of Transformer Layer
 #
 # Input projection for MultiheadAttention
 # ----------------------------------------
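
As a minimal sketch of the idea (dimensions are illustrative): pack the three projection weights into one ``nn.Linear`` so that a single matmul replaces three.

import torch
import torch.nn as nn

E_total = 512
packed_proj = nn.Linear(E_total, 3 * E_total, bias=False)
x = torch.randn(2, 128, E_total)           # self-attention: query == key == value input
# One matmul instead of three, then split the result back into query/key/value
q, k, v = packed_proj(x).chunk(3, dim=-1)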
@@ -505,7 +504,7 @@ def forward(self, query):
 # SwiGLU feed forward network of Transformer Layer
 # ------------------------------------------------
 # SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
-# network of the transformer layer (e.g. Llama). A FFN with SwiGLU activation is defined as
+# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
 
 class SwiGLUFFN(nn.Module):
     def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
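
For reference, the computation being introduced here is the Llama-style ``w2(silu(w1(x)) * w3(x))``. A simplified sketch that omits the ``multiple_of``/``ffn_dim_multiplier`` hidden-size rounding used in the tutorial:

import torch.nn as nn
import torch.nn.functional as F

class SimpleSwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x):
        # SwiGLU: SiLU-gated linear unit followed by the output projection
        return self.w2(F.silu(self.w1(x)) * self.w3(x))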
@@ -601,45 +600,47 @@ def forward(self, x):
 
 
 ################################################################################
-# [PENDING] KV-Caching with NJT
-# ----------------------------
-# During decoding in inference, the query comprises of the current token. However,
-# the key and value comprises of all the previous keys and values in addition to
-# the current token.
-#
-# When we do batched inference, each batch item will be at a different stage of
-# decoding, so we expect the keys and values to have different sequence lengths.
-# The query is a dense tensor of shape ``[B, 1, E_qk]`` and the keys and values
-# will be of shapes ``[B, *, E_qk]`` and ``[B, *, E_v]`` where ``B`` represents
-# batch size, ``*`` represents varying sequence lengths and ``E_qk`` and ``E_v``
-# are embedding dimensions for query/key and value respectively.
-
-# Directly related to the above point is the idea of KV-Caching. This is a technique
-# that is used in inference to reduce the latency of decoding. The idea is to cache
-# the key and value tensors for the previous tokens and use them for the current
-# token. This is especially useful when the sequence length is long.
-
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/135722
-
-
-################################################################################
-# [PENDING] Relative Positional Embedding with NJT (FlexAttention + NJT)
+# ALiBi with NJT (FlexAttention + NJT)
 # ---------------------------------------------------------------------
-#
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/136792
+# NJT also composes with the ``FlexAttention`` module. This is a generalization
+# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
+# to the attention score. The example below takes the ``alibi_mod`` from
+# attention gym and uses it with nested input tensors.
 
+from torch.nn.attention.flex_attention import flex_attention
+
+def generate_alibi_bias(H: int):
+    """Returns an alibi bias score_mod given the number of heads H
+    Args:
+        H: number of heads
+    Returns:
+        alibi_bias: alibi bias score_mod
+    """
+    def alibi_mod(score, b, h, q_idx, kv_idx):
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (q_idx - kv_idx) * scale
+        return score + bias
+    return alibi_mod
+
+query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+n_heads, D = 8, E_q // 8
+alibi_score_mod = generate_alibi_bias(n_heads)
+query = (
+    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+value = (
+    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
 
 ################################################################################
-# [PENDING] Grouped Query Attention with NJT
-# ------------------------------------------
-#
-# Grouped Query Attention refers to using a number of key/value heads that is
-# less than the number of query heads. Compared to MultiheadAttention, this
-# decreases the size of the kv-cache during inference.
-#
-# We can implement this using nested tensors as follows
+# And more
+# --------
 #
-# FIXME: Pending FlexAttention/testing for NJT with grouped query attention
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks such as KV-Caching, Grouped Query Attention
+# etc.
 
 
 ################################################################################
@@ -649,7 +650,7 @@
 # There are several good examples of using various performant building blocks to
 # implement various transformer architectures. Some examples include
 #
-# * `gpt_fast <https://github.com/pytorch-labs/gpt-fast>`_
-# * `sam_fast <https://github.com/pytorch-labs/sam-fast>`_
-# * `lucidrains implementation of ViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/nested_tensor.py>`_
+# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
+# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
+# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
 # * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_
