=================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

+ .. note::
+    This tutorial should be run with the latest nightly, or, when available, 2.6.
+

The ``torch.nn`` module currently provides various ``Transformer``-related layers.
In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
@@ -253,73 +256,72 @@ def forward(self,

        return attn_output

- # .. dropdown::
-
- ###############################################################################
- # Utilities
- # =========
- # In this section, we include a utility to generate semi-realistic data using
- # Zipf distribution for sentence lengths. This is used to generate the nested
- # query, key and value tensors. We also include a benchmark utility.
-

- import numpy as np
-
- def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
-     # generate fake corpus by unigram Zipf distribution
-     # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
-     sentence_lengths = np.empty(batch_size, dtype=int)
-     for ibatch in range(batch_size):
-         sentence_lengths[ibatch] = 1
+ ###############################################################################
+ # Utilities
+ # =========
+ # In this section, we include a utility to generate semi-realistic data using
+ # Zipf distribution for sentence lengths. This is used to generate the nested
+ # query, key and value tensors. We also include a benchmark utility.
+
+
+ import numpy as np
+
+ def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
+     # generate fake corpus by unigram Zipf distribution
+     # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
+     sentence_lengths = np.empty(batch_size, dtype=int)
+     for ibatch in range(batch_size):
+         sentence_lengths[ibatch] = 1
+         word = np.random.zipf(alpha)
+         while word != 3 and word != 386 and word != 858:
+             sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
-         while word != 3 and word != 386 and word != 858:
-             sentence_lengths[ibatch] += 1
-             word = np.random.zipf(alpha)
-     return torch.tensor(sentence_lengths)
-
- # Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
- # in the form of nested tensors with the jagged layout.
- def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
-     # generate semi-realistic data using Zipf distribution for sentence lengths
-     sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)
-
-     # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
-     # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
-     # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
-     if query_seq_len_1:
+     return torch.tensor(sentence_lengths)
+
+ # Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
+ # in the form of nested tensors with the jagged layout.
+ def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
+     # generate semi-realistic data using Zipf distribution for sentence lengths
+     sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)
+
+     # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
+     # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
+     # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
+     if query_seq_len_1:
        query = torch.nested.nested_tensor([
            torch.randn(1, E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)
-     else:
+     else:
        query = torch.nested.nested_tensor([
            torch.randn(l.item(), E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)

-     key = torch.nested.nested_tensor([
-         torch.randn(s.item(), E_k, dtype=dtype, device=device)
-         for s in sentence_lengths
-     ], layout=torch.jagged)
+     key = torch.nested.nested_tensor([
+         torch.randn(s.item(), E_k, dtype=dtype, device=device)
+         for s in sentence_lengths
+     ], layout=torch.jagged)

-     value = torch.nested.nested_tensor([
-         torch.randn(s.item(), E_v, dtype=dtype, device=device)
-         for s in sentence_lengths
-     ], layout=torch.jagged)
+     value = torch.nested.nested_tensor([
+         torch.randn(s.item(), E_v, dtype=dtype, device=device)
+         for s in sentence_lengths
+     ], layout=torch.jagged)

-     return query, key, value, sentence_lengths
+     return query, key, value, sentence_lengths

- import timeit
- import math
+ import timeit
+ import math

- def benchmark(func, *args, **kwargs):
-     torch.cuda.synchronize()
-     torch.cuda.reset_peak_memory_stats()
-     begin = timeit.default_timer()
-     output = func(*args, **kwargs)
-     torch.cuda.synchronize()
-     end = timeit.default_timer()
-     return output, (end - begin), torch.cuda.max_memory_allocated()
+ def benchmark(func, *args, **kwargs):
+     torch.cuda.synchronize()
+     torch.cuda.reset_peak_memory_stats()
+     begin = timeit.default_timer()
+     output = func(*args, **kwargs)
+     torch.cuda.synchronize()
+     end = timeit.default_timer()
+     return output, (end - begin), torch.cuda.max_memory_allocated()

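A minimal usage sketch of how ``gen_batch`` and ``benchmark`` compose, assuming a CUDA device and illustrative sizes (the tutorial defines its own ``N``, ``E_q``, ``E_k``, ``E_v`` and ``device``):

```
# Illustrative sizes only; the tutorial sets its own values for these names.
N, E_q, E_k, E_v = 512, 512, 512, 512
device = "cuda"  # ``benchmark`` calls torch.cuda.synchronize, so a GPU is assumed

# ``benchmark`` returns (function output, elapsed seconds, peak CUDA memory in bytes).
(query, key, value, lengths), seconds, peak_mem = benchmark(
    gen_batch, N, E_q, E_k, E_v, device
)
print(f"gen_batch: {seconds:.4f} s, peak memory {peak_mem / 1e9:.2f} GB")
```

The same pattern is what the comparison below uses to time the padded and nested paths.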
##############################################################################
# We will now demonstrate the performance improvements of using nested tensors
@@ -395,6 +397,16 @@ def benchmark(func, *args, **kwargs):
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

######################################################################################
+ # For reference, here are some sample outputs on A100:
+ #
+ # ```
+ # padded_time=0.03454, padded_peak_memory=4.14 GB
+ # nested_time=0.00612, nested_peak_memory=0.76 GB
+ # Difference between vanilla and nested result 0.0
+ # Nested speedup: 5.65
+ # Nested peak memory reduction 3.39 GB
+ # ```
+ #
# We can also see the same for backward pass

for i, entry_length in enumerate(sentence_lengths):
@@ -414,6 +426,20 @@ def benchmark(func, *args, **kwargs):
print("Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item())
print("Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item())

+ ##################################################################################
+ # Sample outputs on A100:
+ #
+ # ```
+ # padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+ # nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
+ # Nested backward speedup: 144.13
+ # Nested backward peak memory reduction 1.86 GB
+ # Difference in out_proj.weight.grad 0.000244140625
+ # Difference in packed_proj.weight.grad 0.001556396484375
+ # Difference in out_proj.bias.grad 0.0
+ # Difference in packed_proj.bias.grad 0.001953125
+ # ```
+
##################################################################################
# GPT-style layer
# ---------------
@@ -424,8 +450,9 @@ def benchmark(func, *args, **kwargs):
# ``is_causal=True``.
#
# We demonstrate examples of implementing the rest of the ``nn`` layers
- # `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
- # tutorial for brevity.
+ # `here <https://github.com/mikaylagawarecki/transformer_tutorial_accompaniment>`_
+ # but omit that from this tutorial for brevity.
+
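To make the ``is_causal=True`` point concrete, here is a minimal causal self-attention sketch built on ``scaled_dot_product_attention``; the class name, sizes, and packed ``qkv`` projection are illustrative assumptions, not the tutorial's exact layer.

```
# Minimal sketch of a GPT-style causal self-attention block (dense tensors,
# illustrative names and sizes).
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # packed q/k/v projection
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (B, S, E) -> split heads -> (B, H, S, D)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.unflatten(-1, [self.num_heads, self.head_dim]).transpose(1, 2)
                   for t in (q, k, v))
        # is_causal=True applies the causal mask inside SDPA; no explicit mask is built
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(attn.transpose(1, 2).flatten(-2))

x = torch.randn(2, 16, 64)
print(CausalSelfAttention(64, 8)(x).shape)  # torch.Size([2, 16, 64])
```

Passing ``is_causal=True`` lets SDPA handle the causal masking internally, which is the point the paragraph above is making.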

###############################################################################
# Going one step further
@@ -440,10 +467,85 @@ def benchmark(func, *args, **kwargs):
# In this section, we will discuss various functionalities using the
# aforementioned building blocks. In particular,
#
- # * Packed Projection
# * Cross Attention
# * Fully masked rows no longer cause ``NaN``s
# * Modifying attention score: ALiBi with FlexAttention and NJT
+ # * Packed Projection
+
+ ###############################################################################
+ # Cross Attention
+ # ---------------
+ # Cross attention is a form of attention where the query and key/value tensors
+ # are from different sequences.
+ #
+ # One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
+ # from the decoder and the key/value come from the encoder.
+ #
+ # The above MultiheadAttention layer nicely generalizes to this case with nested
+ # tensors for both query and key/value.
+
+ query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
+ _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
+
+ print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}")
+ print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
+ out = new_mha_layer(query, key, value, is_causal=False)
+
+
+ ################################################################################
+ # Fully masked rows no longer cause NaNs
+ # --------------------------------------
+ #
+ # There has been a long-standing issue with ``nn.MultiheadAttention`` and
+ # ``scaled_dot_product_attention`` where, if a row was fully masked out, the output
+ # of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
+ # This is because the softmax over an empty set is undefined.
+ #
+ # Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
+ # this is no longer the case: fully masked rows in ``scaled_dot_product_attention``
+ # now produce well-defined outputs instead of NaN. For cases where ``nn.MHA`` does
+ # not employ the "fast-path", this will also apply.
+ #
+ # Using a custom MHA layer with NJTs is strongly recommended over the
+ # existing "fast-path" in ``nn.MultiheadAttention``, as NJT's ability to model raggedness
+ # appropriately makes it possible to properly express empty sequences.
+
+
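As a small sketch of the behavior described above (assuming a recent PyTorch build with the fix linked in the PR), a boolean ``attn_mask`` whose first query row is entirely ``False`` no longer yields ``NaN`` in the corresponding output row:

```
# Sketch only: illustrative shapes, boolean mask where True means "attend".
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

attn_mask = torch.ones(4, 4, dtype=torch.bool)
attn_mask[0, :] = False  # the first query row attends to nothing

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# On builds with the fix, the fully masked row is well-defined rather than NaN.
print(torch.isnan(out).any())
```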
+ ################################################################################
+ # FlexAttention + NJT
+ # ---------------------------------------------------------------------
+ # NJT also composes with the ``FlexAttention`` module. This is a generalization
+ # of the ``MultiheadAttention`` layer that allows for arbitrary modifications
+ # to the attention score. The example below takes the ``alibi_mod``
+ # that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
+ # `attention gym <https://github.com/pytorch-labs/attention-gym>`_ and uses it
+ # with nested input tensors.
+
+ from torch.nn.attention.flex_attention import flex_attention
+
+ def generate_alibi_bias(H: int):
+     """Returns an alibi bias score_mod given the number of heads H
+     Args:
+         H: number of heads
+     Returns:
+         alibi_bias: alibi bias score_mod
+     """
+     def alibi_mod(score, b, h, q_idx, kv_idx):
+         scale = torch.exp2(-((h + 1) * 8.0 / H))
+         bias = (q_idx - kv_idx) * scale
+         return score + bias
+     return alibi_mod
+
+ query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+ n_heads, D = 8, E_q // 8
+ alibi_score_mod = generate_alibi_bias(n_heads)
+ query = (
+     query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ value = (
+     value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)


###############################################################################
# Packed Projection
@@ -567,80 +669,6 @@ def forward(self, x):
_, time_packed, _ = benchmark(packed_swigluffn, q)
print(f"SwiGLUFFN: {time}s, PackedSwiGLUFFN: {time_packed}s, speedup: {time/time_packed:.2f}x")

- ###############################################################################
- # Cross Attention
- # ---------------
- # Cross attention is a form of attention where the query and key/value tensors
- # are from different sequences.
- #
- # One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
- # from the decoder and the key/value come from the encoder.
- #
- # The above MultiheadAttention layer nicely generalizes to this case with nested
- # tensors for both query and key/value.
-
- query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
- _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
-
- print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}")
- print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
- out = new_mha_layer(query, key, value, is_causal=False)
-
-
- ################################################################################
- # Fully masked rows no longer cause NaNs
- # --------------------------------------
- #
- # There has been a long standing issue with ``nn.MultiheadAttention`` and
- # ``scaled_dot_product_attention`` where if a row was fully masked out, the output
- # of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
- # This is because the softmax over an empty set is undefined.
- #
- # Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
- # this is no longer the case. Instead, fully masked rows in ``scaled_dot_product_attention``.
- # For cases where ``nn.MHA`` does not employ the "fast-path", this will also apply.
- #
- # Using a custom MHA layer with NJTs is strongly recommended over the
- # existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness
- # appropriately makes it possible to distinguish when there is an empty sequence.
-
-
- ################################################################################
- # ALiBi with NJT (FlexAttention + NJT)
- # ---------------------------------------------------------------------
- # NJT also composes with the ``FlexAttention`` module. This is a generalization
- # of the ``MultiheadAttention`` layer that allows for arbitrary modifications
- # to the attention score. The example below takes the ``alibi_mod`` from
- # attention gym and uses it with nested input tensors.
-
- from torch.nn.attention.flex_attention import flex_attention
-
- def generate_alibi_bias(H: int):
-     """Returns an alibi bias score_mod given the number of heads H
-     Args:
-         H: number of heads
-     Returns:
-         alibi_bias: alibi bias score_mod
-     """
-     def alibi_mod(score, b, h, q_idx, kv_idx):
-         scale = torch.exp2(-((h + 1) * 8.0 / H))
-         bias = (q_idx - kv_idx) * scale
-         return score + bias
-     return alibi_mod
-
- query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
- n_heads, D = 8, E_q // 8
- alibi_score_mod = generate_alibi_bias(n_heads)
- query = (
-     query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
- )
- key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
- value = (
-     value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
- )
- out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-

################################################################################
# Extended examples
# -----------------