"""
- Dismantling the ``nn.Transformer`` modules for gains and profits
- =================================================================
+ Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()
+ =====================================================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

.. note::
    This tutorial should be run with the latest nightly, or, when available, 2.6.

- The ``torch.nn`` module currently provides various ``Transformer``-related layers.
- In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
- ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
- of layers was initially implemented following the `Attention is All
- You Need <https://arxiv.org/abs/1706.03762>`_ paper. Since then, various improvements
- were made to try to make these layers more flexible.
-
- While historically these layers intended to provide out-of-the-box, performant
- solutions, we make the observations that
-
- 1. People want to add slight customizations to their transformer layers
- 2. Writing these layers and customizations is not hard
-
-
- Supporting all transformer variants via a small number of out of the box layers would
- yield too many keyword arguments. This tutorial will describe how to build your
- own performant transformer layers following our recommended best practices.
- The technologies used will be the following
+ Over the past few years, the PyTorch team has developed various lower-level
+ features that, when composed, can create a variety of transformer variants. These
+ include:

1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
2. ``scaled_dot_product_attention``
3. ``torch.compile()``
4. ``FlexAttention``

+ This tutorial will give a brief overview of the above technologies and
+ demonstrate how they can be composed to yield flexible and performant transformer
+ layers with improved user experience.
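+
+ As a minimal sketch of that composition (the shapes, dtypes, and names below are
+ illustrative only, not taken from the rest of this tutorial), ``scaled_dot_product_attention``
+ can be run directly over a nested tensor with the ``torch.jagged`` layout under
+ ``torch.compile()``:
+
+ .. code::
+
+     import torch
+     import torch.nn.functional as F
+
+     # A batch of variable-length sequences, each of shape [seq_len, n_heads, head_dim]
+     seqs = [torch.randn(5, 8, 64, device="cuda", dtype=torch.bfloat16),
+             torch.randn(9, 8, 64, device="cuda", dtype=torch.bfloat16)]
+     # Dim 1 (seq_len) becomes the ragged dimension of the resulting NJT
+     nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
+
+     @torch.compile
+     def self_attend(x):
+         # SDPA expects [batch, n_heads, seq_len, head_dim]
+         x = x.transpose(1, 2)
+         return F.scaled_dot_product_attention(x, x, x)
+
+     out = self_attend(nt)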
+
+ One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers.
+ In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
+ ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
+ of layers was initially implemented following the `Attention is All
+ You Need <https://arxiv.org/abs/1706.03762>`_ paper. The components discussed in
+ this tutorial provide improved user experience, flexibility and performance over
+ the existing ``nn`` layers.
+
Is this tutorial for me?
========================

+ If you are wondering about what building blocks the ``torch`` library provides
+ for writing your own transformer layers and best practices, you are in the
+ right place. Please keep reading!
+
If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,
with some examples being:

* `xformers <https://github.com/facebookresearch/xformers>`_
* `torchtune <https://github.com/pytorch/torchtune>`_

- Please head there instead!
-
If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
- If you are wondering about what building blocks the ``torch`` library provides
- for writing your own transformer layers and best practices, you are in the
- right place, please keep reading!
-

"""

@@ -393,7 +388,7 @@ def benchmark(func, *args, **kwargs):
print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
- print("Difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
+ print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")
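These hunks call a ``benchmark`` helper whose definition is not shown in this diff. A
minimal sketch of such a helper, assuming it returns the function output, the elapsed
time in seconds, and the CUDA peak memory in bytes (consistent with how the results are
printed above), might look like::

    import timeit
    import torch

    def benchmark(func, *args, **kwargs):
        # Reset the peak-memory counter so max_memory_allocated reflects only this call
        torch.cuda.reset_peak_memory_stats()
        begin = timeit.default_timer()
        output = func(*args, **kwargs)
        # Wait for all queued CUDA kernels to finish before stopping the clock
        torch.cuda.synchronize()
        end = timeit.default_timer()
        return output, end - begin, torch.cuda.max_memory_allocated()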
@@ -404,7 +399,7 @@ def benchmark(func, *args, **kwargs):
#
#    padded_time=0.03454, padded_peak_memory=4.14 GB
#    nested_time=0.00612, nested_peak_memory=0.76 GB
- #    Difference between vanilla and nested result 0.0
+ #    Max difference between vanilla and nested result 0.0
#    Nested speedup: 5.65
#    Nested peak memory reduction 3.39 GB
#
@@ -432,14 +427,14 @@ def benchmark(func, *args, **kwargs):
#
# .. code::
#
- #    ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
- #    ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
+ #    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+ #    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
#    Nested backward speedup: 144.13
#    Nested backward peak memory reduction 1.86 GB
- #    Difference in ``out_proj.weight.grad`` 0.000244140625
- #    Difference in ``packed_proj.weight.grad`` 0.001556396484375
- #    Difference in ``out_proj.bias.grad`` 0.0
- #    Difference in ``packed_proj.bias.grad`` 0.001953125
+ #    Difference in out_proj.weight.grad 0.000244140625
+ #    Difference in packed_proj.weight.grad 0.001556396484375
+ #    Difference in out_proj.bias.grad 0.0
+ #    Difference in packed_proj.bias.grad 0.001953125
#

##################################################################################
@@ -493,6 +488,53 @@ def benchmark(func, *args, **kwargs):
print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
out = new_mha_layer(query, key, value, is_causal=False)

+ ########################################################################################
+ # As above, we can compare this against the vanilla compiled ``nn.MultiheadAttention``.
+
+ torch.manual_seed(6)
+ query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
+ _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
+ padded_query, padded_key, padded_value = (
+     t.to_padded_tensor(0.0) for t in (query, key, value)
+ )
+
+ key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]
+
+ # warmup compile
+ warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
+ warmup_vanilla_result = vanilla_mha_layer(padded_query,
+                                           padded_key,
+                                           padded_value,
+                                           key_padding_mask=key_padding_mask,
+                                           need_weights=False,
+                                           is_causal=False)
+
+ nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, key, value, is_causal=False)
+ (padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer,
+                                                                 padded_query,
+                                                                 padded_key,
+                                                                 padded_value,
+                                                                 key_padding_mask=key_padding_mask,
+                                                                 need_weights=False,
+                                                                 is_causal=False)
+ padded_nested_result = nested_result.to_padded_tensor(0.0)
+ for i, entry_length in enumerate(q_len):
+     # padding-specific step: remove output projection bias from padded entries for fair comparison
+     padded_result[i, entry_length:, :] = 0.0
+
+ print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
+ print(f"Nested speedup: {(padded_time/nested_time):.2f}")
+ print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")
+
+ ##################################################################################
+ # Sample outputs on A100:
+ #
+ # .. code::
+ #
+ #    Max difference between vanilla and nested result 0.0
+ #    Nested speedup: 4.01
+ #    Nested peak memory reduction 1.40 GB
+ #


################################################################################
# Fully masked rows no longer cause NaNs
@@ -549,6 +591,29 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

+ ###############################################################################
+ # In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
+ # with NJTs via the ``create_nested_block_mask`` function. This is useful for
+ # taking advantage of the sparsity of the mask to speed up the attention computation.
+ # In the following example, we show how to create a causal block mask using this
+ # utility.
+
+ from torch.nn.attention.flex_attention import create_nested_block_mask
+
+ def causal_mask(b, h, q_idx, kv_idx):
+     return q_idx >= kv_idx
+
+ query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+ block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
+ query = (
+     query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ value = (
+     value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ out_flex = flex_attention(query, key, value, block_mask=block_mask)
+

###############################################################################
# Packed Projection
# -----------------
@@ -579,8 +644,8 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

-    def forward(self, query):
-        return self.q_proj(query), self.k_proj(query), self.v_proj(query)
+    def forward(self, x):
+        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
@@ -591,7 +656,7 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)

- B, D, dtype = 256, 4096, torch.bfloat16
+ B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision('high')
in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16))
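The packed projection being compiled above boils down to a single fused matmul whose
output is split with ``torch.chunk``. A standalone sketch of that idea, with
illustrative (not the tutorial's) sizes::

    import torch
    import torch.nn as nn

    E = 1024
    # One fused weight of shape [3 * E, E] instead of three separate [E, E] projections
    fused_qkv = nn.Linear(E, 3 * E, bias=False)

    x = torch.randn(2, 16, E)                       # [batch, seq_len, E]
    q, k, v = torch.chunk(fused_qkv(x), 3, dim=-1)  # each [batch, seq_len, E]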
@@ -606,6 +671,7 @@ def forward(self, query):
# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
+ # On my A100 prints 1.05x speedup
print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")


##################################################
@@ -669,6 +735,7 @@ def forward(self, x):
# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
+ # On my A100 prints 1.08x speedup
print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")


################################################################################
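A SwiGLU feed-forward network lends itself to the same packing trick: its two input
projections can be fused into one matmul and split afterwards. A minimal sketch of that
pattern (class, names, and sizes here are illustrative, not the tutorial's)::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyPackedSwiGLU(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            # gate and up projections fused into a single matmul
            self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)

        def forward(self, x):
            x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
            return self.w2(F.silu(x1) * x3)

    out = TinyPackedSwiGLU(1024, 4096)(torch.randn(2, 16, 1024))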