Commit 605e750

address comments: improve pass-KV description
1 parent a3556fa commit 605e750

prototype_source/context_parallel.rst

Lines changed: 12 additions & 7 deletions
@@ -27,19 +27,21 @@ Context Parallel is an approach used in large language model training to reduce
 It breaks the constraint on input sequence length resulting from peak memory usage on storing activations in Transformer blocks.

 The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
-Ring Attention shuffles the KV shards and calculates the partial attention scores,
-repeats until all KV shards have been used on each device.
-Two Ring Attention variants have been implemented: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
-The pass-KV approach all-gathers KV shards while performing the local SDPA (Scaled Dot Product Attention) then performs the rest when the communication completes.
-The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA computation and the all-to-all communication
-necessary for the next SDPA.
+Ring Attention shuffles the KV shards and calculates the partial attention scores, repeating until all KV shards have been used on each device.
+Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__:
+1. The all-gather based pass-KV algorithm, used in Llama3 training, initially performs an all-gather on the key and value tensors, followed by computing the attention output for the
+local query tensor chunk. Our modified all-gather based pass-KV algorithm concurrently all-gathers KV shards and computes the attention output for the local query tensor chunk
+using the local key and value tensor chunks, followed by a final computation of the attention output for the local query tensor and the remaining KV shards. This allows some degree of
+overlap between the attention computation and the all-gather collective.
+2. The all-to-all based pass-KV approach uses interleaved all-to-all collectives to ring shuffle KV shards, overlapping the SDPA computation with the all-to-all communication
+necessary for the next SDPA.

 The Context Parallel APIs consist of two parts:

 1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
 will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
 argument ``buffers`` and ``buffer_seq_dims`` respectively.
-2. ``set_rotate_method()`` allows users to choose between the pass-KV approach and the all-to-all approach.
+2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.


 Setup
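
As a quick illustration of the ``context_parallel()`` API documented in the hunk above, here is a minimal usage sketch (not part of this commit). The device mesh setup, tensor shapes, and sharding dimension are assumptions, and the prototype import paths may change between releases.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

# Assumes the default process group is already initialized (e.g. launched with torchrun)
# and that each rank owns one CUDA device.
world_size = dist.get_world_size()
cp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))

# Full (unsharded) query/key/value tensors, identical on every rank:
# shape (batch, heads, seq_len, head_dim); the sequence length sits on dim 2.
qkv = [torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

# Inside the context, each buffer is sharded in place along its sequence dimension
# and SDPA dispatches to Ring Attention across the mesh.
with context_parallel(cp_mesh, buffers=qkv, buffer_seq_dims=[2, 2, 2]):
    cp_out = F.scaled_dot_product_attention(*qkv, is_causal=True)
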
@@ -213,6 +215,9 @@ You can choose the desired shards rotation approach in Ring Attention by using `
 cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

+
+The default rotation approach is the all-gather based pass-KV.
+

 Conclusion
 ----------

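A minimal sketch of selecting the rotation approach with ``set_rotate_method()``, which the second hunk refers to; the import path and the accepted string values follow the prototype tutorial and may change.

from torch.distributed.tensor.experimental._attention import set_rotate_method

# Select the shard rotation approach before running SDPA under context_parallel().
set_rotate_method("alltoall")   # all-to-all based pass-KV
# set_rotate_method("allgather")  # the default: all-gather based pass-KV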