Commit f27bc5c

Fix messaging
1 parent 32a7e99 commit f27bc5c

2 files changed: +105 -38 lines

index.rst

Lines changed: 1 addition & 1 deletion
@@ -666,7 +666,7 @@ Welcome to PyTorch Tutorials
 
 
 .. customcarditem::
-   :header: Dismantling the nn.Transformer modules for gains and profits
+   :header: Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()
    :card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch.
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
    :link: intermediate/transformer_building_blocks.html

intermediate_source/transformer_building_blocks.py

Lines changed: 104 additions & 37 deletions
@@ -1,38 +1,39 @@
 """
-Dismantling the ``nn.Transformer`` modules for gains and profits
-=================================================================
+Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()
+=====================================================================================================
 **Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_
 
 .. note::
     This tutorial should be run with the latest nightly, or, when available, 2.6.
 
-The ``torch.nn`` module currently provides various ``Transformer``-related layers.
-In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
-``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
-of layers was initially implemented following the `Attention is All
-You Need <https://arxiv.org/abs/1706.03762>`_ paper. Since then, various improvements
-were made to try to make these layers more flexible.
-
-While historically these layers intended to provide out-of-the-box, performant
-solutions, we make the observations that
-
-1. People want to add slight customizations to their transformer layers
-2. Writing these layers and customizations is not hard
-
-
-Supporting all transformer variants via a small number of out of the box layers would
-yield too many keyword arguments. This tutorial will describe how to build your
-own performant transformer layers following our recommended best practices.
-The technologies used will be the following
+Over the past few years, the PyTorch team has developed various lower level
+features that, when composed, can create a variety of transformer variants. These
+include:
 
 1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
 2. ``scaled_dot_product_attention``
 3. ``torch.compile()``
 4. ``FlexAttention``
 
+This tutorial will give a brief overview of the above technologies and
+demonstrate how they can be composed to yield flexible and performant transformer
+layers with improved user experience.
+
+One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers.
+In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
+``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
+of layers was initially implemented following the `Attention is All
+You Need <https://arxiv.org/abs/1706.03762>`_ paper. The components discussed in
+this tutorial provide improved user experience, flexibility and performance over
+the existing ``nn`` layers.
+
 Is this tutorial for me?
 ========================
 
+If you are wondering about what building blocks the ``torch`` library provides
+for writing your own transformer layers and best practices, you are in the
+right place, please keep reading!
+
 If you are looking for an out-of-the-box implementation of a popular transformer
 architecture, note that there are many open-source libraries that provide them,
 with some examples being:
@@ -41,15 +42,9 @@
 * `xformers <https://github.com/facebookresearch/xformers>`_
 * `torchtune <https://github.com/pytorch/torchtune>`_
 
-Please head there instead!
-
 If you are only interested in performant attention score modifications, please
 head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
-If you are wondering about what building blocks the ``torch`` library provides
-for writing your own transformer layers and best practices, you are in the
-right place, please keep reading!
-
 
 """
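For readers skimming this diff, here is a minimal sketch of how the four building blocks listed in the new introduction fit together. It is illustrative only, not part of the commit, and it assumes a CUDA device and a recent nightly build:

    import torch
    import torch.nn.functional as F

    device = "cuda"  # assumption: the NJT attention fast paths target CUDA, as in the tutorial
    n_heads, head_dim = 8, 64
    E = n_heads * head_dim

    # Pack two sequences of different lengths into one jagged (NJT) tensor of shape (B, j, E).
    seqs = [torch.randn(L, E, device=device) for L in (32, 47)]
    qkv = torch.nested.nested_tensor(seqs, layout=torch.jagged)

    # Split heads to get the (B, n_heads, j, head_dim) layout that SDPA expects.
    q = k = v = qkv.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)

    def attend(q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

    # torch.compile() accepts jagged inputs; the output is again a jagged nested tensor.
    out = torch.compile(attend)(q, k, v)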

@@ -393,7 +388,7 @@ def benchmark(func, *args, **kwargs):
 
 print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
 print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
-print("Difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
+print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
 print(f"Nested speedup: {(padded_time/nested_time):.2f}")
 print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

@@ -404,7 +399,7 @@ def benchmark(func, *args, **kwargs):
 #
 # padded_time=0.03454, padded_peak_memory=4.14 GB
 # nested_time=0.00612, nested_peak_memory=0.76 GB
-# Difference between vanilla and nested result 0.0
+# Max difference between vanilla and nested result 0.0
 # Nested speedup: 5.65
 # Nested peak memory reduction 3.39 GB
 #
@@ -432,14 +427,14 @@ def benchmark(func, *args, **kwargs):
 #
 # .. code::
 #
-#    ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
-#    ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
+#    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+#    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
 #    Nested backward speedup: 144.13
 #    Nested backward peak memory reduction 1.86 GB
-#    Difference in ``out_proj.weight.grad`` 0.000244140625
-#    Difference in ``packed_proj.weight.grad`` 0.001556396484375
-#    Difference in ``out_proj.bias.grad`` 0.0
-#    Difference in ``packed_proj.bias.grad`` 0.001953125
+#    Difference in out_proj.weight.grad 0.000244140625
+#    Difference in packed_proj.weight.grad 0.001556396484375
+#    Difference in out_proj.bias.grad 0.0
+#    Difference in packed_proj.bias.grad 0.001953125
 #
 
 ##################################################################################
@@ -493,6 +488,53 @@ def benchmark(func, *args, **kwargs):
 print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
 out = new_mha_layer(query, key, value, is_causal=False)
 
+########################################################################################
+# As above, we can compare this against the vanilla compiled ``nn.MultiheadAttention``.
+
+torch.manual_seed(6)
+query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
+_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
+padded_query, padded_key, padded_value = (
+    t.to_padded_tensor(0.0) for t in (query, key, value)
+)
+
+key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]
+
+# warmup compile
+warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
+warmup_vanilla_result = vanilla_mha_layer(padded_query,
+                                          padded_key,
+                                          padded_value,
+                                          key_padding_mask=key_padding_mask,
+                                          need_weights=False,
+                                          is_causal=False)
+
+nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, key, value, is_causal=False)
+(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer,
+                                                                padded_query,
+                                                                padded_key,
+                                                                padded_value,
+                                                                key_padding_mask=key_padding_mask,
+                                                                need_weights=False,
+                                                                is_causal=False)
+padded_nested_result = nested_result.to_padded_tensor(0.0)
+for i, entry_length in enumerate(q_len):
+    # padding-specific step: remove output projection bias from padded entries for fair comparison
+    padded_result[i, entry_length:, :] = 0.0
+
+print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
+print(f"Nested speedup: {(padded_time/nested_time):.2f}")
+print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")
+
+##################################################################################
+# Sample outputs on A100:
+#
+# .. code::
+#
+#    Max difference between vanilla and nested result 0.0
+#    Nested speedup: 4.01
+#    Nested peak memory reduction 1.40 GB
+#
 
 ################################################################################
 # Fully masked rows no longer cause NaNs
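The comparison added above relies on the tutorial's ``gen_batch`` and ``benchmark`` helpers, which this commit does not touch. For orientation, here is a rough sketch of a helper with the same return signature (output, elapsed seconds, peak CUDA memory in bytes); this is an assumption about the helper's shape, not the file's actual definition:

    import timeit
    import torch

    def benchmark(func, *args, **kwargs):
        # Time one call and record peak CUDA memory, matching how the
        # results are consumed by the prints above.
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        begin = timeit.default_timer()
        output = func(*args, **kwargs)
        torch.cuda.synchronize()
        end = timeit.default_timer()
        return output, (end - begin), torch.cuda.max_memory_allocated()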
@@ -549,6 +591,29 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
 )
 out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
 
+###############################################################################
+# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
+# with NJTs via the ``create_nested_block_mask`` function. This is useful for
+# taking advantage of the sparsity of the mask to speed up the attention computation.
+# In the following example, we show how to create a causal block mask using this
+# utility.
+
+from torch.nn.attention.flex_attention import create_nested_block_mask
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
+query = (
+    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+value = (
+    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+out_flex = flex_attention(query, key, value, block_mask=block_mask)
+
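The ``create_nested_block_mask`` call above accepts any ``mask_mod``, so other sparsity patterns drop in with no other changes. As an illustration (not part of the commit), a sliding-window causal mask that could be passed in place of ``causal_mask``; ``WINDOW`` is a hypothetical constant:

    WINDOW = 128  # hypothetical local-attention width

    def sliding_window_causal(b, h, q_idx, kv_idx):
        # causal, but each query attends only to the previous WINDOW positions
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

    # e.g. block_mask = create_nested_block_mask(sliding_window_causal, 1, 1, query, _compile=True)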
 ###############################################################################
 # Packed Projection
 # -----------------
@@ -579,8 +644,8 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
         self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
         self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
 
-    def forward(self, query):
-        return self.q_proj(query), self.k_proj(query), self.v_proj(query)
+    def forward(self, x):
+        return self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
 class PackedInputProjection(nn.Module):
     def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
@@ -591,7 +656,7 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
     def forward(self, query):
         return torch.chunk(self.packed_proj(query), 3, dim=-1)
 
-B, D, dtype = 256, 4096, torch.bfloat16
+B, D, dtype = 256, 8192, torch.bfloat16
 
 torch.set_float32_matmul_precision('high')
 in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16))
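The hunks above only show fragments of the projection classes being benchmarked. For context, here is a self-contained sketch of the packed-projection idea, with the ``__init__`` body assumed rather than taken from the file:

    import torch
    import torch.nn as nn

    class PackedInputProjection(nn.Module):
        # Fuse the separate q/k/v projections into a single matmul, then split the result.
        def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

        def forward(self, query):
            return torch.chunk(self.packed_proj(query), 3, dim=-1)

    x = torch.randn(2, 16, 64)
    q, k, v = PackedInputProjection(64, 64)(x)  # each of shape (2, 16, 64)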
@@ -606,6 +671,7 @@ def forward(self, query):
 # benchmark
 (q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
 (q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
+# On my A100 prints 1.05x speedup
 print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")
 
 ##################################################
@@ -669,6 +735,7 @@ def forward(self, x):
 # benchmark
 _, time, _ = benchmark(swigluffn, q)
 _, time_packed, _ = benchmark(packed_swigluffn, q)
+# On my A100 prints 1.08x speedup
 print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")
 
 ################################################################################
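``swigluffn`` and ``packed_swigluffn`` are defined earlier in the file and do not appear in this diff. A minimal sketch of what a packed SwiGLU feed-forward could look like, assuming fused input projections and with names chosen here purely for illustration:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PackedSwiGLUFFN(nn.Module):
        # Fuse the two input projections of a SwiGLU feed-forward into one matmul.
        def __init__(self, dim, hidden_dim, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

        def forward(self, x):
            x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
            return self.w2(F.silu(x1) * x3)

    ffn = PackedSwiGLUFFN(64, 256)
    out = ffn(torch.randn(2, 16, 64))  # shape (2, 16, 64)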
