prototype_source/context_parallel.rst

Context Parallel is an approach used in large language model training to reduce activation memory by sharding the long input sequence across multiple devices.
It breaks the constraint on input sequence length resulting from peak memory usage on storing activations in Transformer blocks.

The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
Ring Attention shuffles the KV shards and computes the partial attention scores, repeating until all KV shards have been used on each device
(a minimal single-process sketch of this blockwise pattern follows the variant list below).
Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__:

1. The all-gather based pass-KV algorithm, used in Llama 3 training, first performs an all-gather on the key and value tensors and then computes the attention output for the
   local query tensor chunk. Our modified all-gather based pass-KV algorithm instead all-gathers the KV shards while concurrently computing the attention output for the local
   query tensor chunk using the local key and value tensor chunks, and finishes by computing the attention output for the local query tensor against the remaining KV shards.
   This allows some degree of overlap between the attention computation and the all-gather collective.
2. The all-to-all based pass-KV approach uses interleaved all-to-all collectives to ring-shuffle the KV shards, overlapping the SDPA (Scaled Dot Product Attention) computation
   with the all-to-all communication necessary for the next SDPA.
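
To make the blockwise pattern concrete, below is a minimal, single-process sketch of the Ring Attention computation: no causal mask, no real communication, and none of the
compute/communication overlap described above. The function and variable names are purely illustrative and not part of the PyTorch API. Each list entry of ``q_shards``,
``k_shards``, and ``v_shards`` stands in for one device's local shard, and partial results are merged with a running log-sum-exp so the blockwise result matches full attention.

.. code:: python

    import torch
    import torch.nn.functional as F

    def ring_attention_simulated(q_shards, k_shards, v_shards):
        # Simulates the ring: "device" ``rank`` keeps its query shard fixed while
        # the KV shards rotate past it, one shard per step.
        world = len(q_shards)
        outputs = []
        for rank in range(world):
            q = q_shards[rank]                              # [seq_local, d]
            acc = torch.zeros_like(q)                       # running numerator
            row_max = torch.full((q.shape[0], 1), float("-inf"))
            row_sum = torch.zeros(q.shape[0], 1)
            for step in range(world):
                src = (rank + step) % world                 # KV shard seen at this step
                k, v = k_shards[src], v_shards[src]
                scores = q @ k.T / q.shape[-1] ** 0.5       # partial attention scores
                blk_max = scores.max(dim=-1, keepdim=True).values
                new_max = torch.maximum(row_max, blk_max)
                # Rescale the accumulators to the new running max, then fold in
                # this block (the standard online-softmax / log-sum-exp merge).
                acc = acc * torch.exp(row_max - new_max) + torch.exp(scores - new_max) @ v
                row_sum = row_sum * torch.exp(row_max - new_max) \
                    + torch.exp(scores - new_max).sum(-1, keepdim=True)
                row_max = new_max
            outputs.append(acc / row_sum)
        return torch.cat(outputs)

    # Sanity check against full attention on the unsharded sequence.
    torch.manual_seed(0)
    shards, seq_local, d = 4, 8, 16
    q = [torch.randn(seq_local, d) for _ in range(shards)]
    k = [torch.randn(seq_local, d) for _ in range(shards)]
    v = [torch.randn(seq_local, d) for _ in range(shards)]
    full = F.scaled_dot_product_attention(
        torch.cat(q)[None, None], torch.cat(k)[None, None], torch.cat(v)[None, None]
    )[0, 0]
    print(torch.allclose(ring_attention_simulated(q, k, v), full, atol=1e-5))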

The Context Parallel APIs consist of two parts (a short usage sketch follows this list):

1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
   the arguments ``buffers`` and ``buffer_seq_dims``, respectively.
2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
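
As a rough sketch of how the two pieces fit together, the snippet below assumes a script launched with ``torchrun`` (so the usual ``WORLD_SIZE``/``LOCAL_RANK`` environment
variables are set) and a ``[batch, heads, seq, head_dim]`` layout with the sequence on dim 2; both are illustrative assumptions for this sketch, not requirements of the API,
and the module paths under ``torch.distributed.tensor.experimental`` are prototype APIs that may change.

.. code:: python

    import os

    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel

    # Assumes launch via ``torchrun --nproc-per-node=<N> this_script.py``.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))

    # Every rank starts from the same full-length q/k/v; the sequence dimension
    # here is dim 2 of a [batch, heads, seq, head_dim] layout (an assumption of
    # this sketch).
    batch, heads, seq, head_dim = 2, 8, 4096, 64
    qkv = [
        torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    ]

    # Inside the context, each buffer is sharded along its sequence dim and the
    # SDPA call is transparently replaced with Ring Attention.
    with context_parallel(device_mesh, buffers=tuple(qkv), buffer_seq_dims=(2, 2, 2)):
        q, k, v = qkv
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

When run under ``torchrun``, each rank then computes attention only for its local sequence shard while exchanging KV shards with its peers.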

Setup

You can choose the desired shards rotation approach in Ring Attention by using ``set_rotate_method()``.
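
For example, continuing the sketch above (a hedged sketch: the import path of ``set_rotate_method`` and the accepted option strings are assumptions about the prototype API
rather than documented guarantees):

.. code:: python

    # Assumed import path and option string for the prototype API.
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # assumed string for the all-to-all based pass-KV;
                                   # the all-gather based approach is the default.
    with context_parallel(device_mesh, buffers=tuple(qkv), buffer_seq_dims=(2, 2, 2)):
        q, k, v = qkv
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)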