
Commit 03dd8d6

address comments: improve API description
1 parent 605e750 commit 03dd8d6


prototype_source/context_parallel.rst

Lines changed: 6 additions & 5 deletions
@@ -33,15 +33,16 @@ Two Ring Attention variants have been implemented: `the all-gather based pass-KV
 local query tensor chunk. Our modified all-gather based pass-KV algorithm concurrently all-gathers KV shards and computes attention output for the local query tensor chunk
 using local key and value tensor chunks, followed by a final computation of attention output for the local query tensor and remaining KV shards. This allows some degree of
 overlap between the attention computation and the all-gather collective.
-2. The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA computation and the all-to-all communication
+2. The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA (Scaled Dot Product Attention) computation and the all-to-all communication
 necessary for the next SDPA.

 The Context Parallel APIs consist of two parts:

-1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
-will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
-argument ``buffers`` and ``buffer_seq_dims`` respectively.
-2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
+1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
+will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
+the arguments ``buffers`` and ``buffer_seq_dims`` respectively. We recommend that users add tensors computed along the sequence dimension to ``buffers``
+and shard them along this dimension.
+2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.


 Setup
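The two APIs touched by this diff could be exercised roughly as follows. This is a hedged sketch based on PyTorch's experimental ``torch.distributed.tensor`` surface as described in the tutorial text; the exact import paths (notably ``set_rotate_method`` living under the private ``_attention`` module), the accepted rotate-method strings, and the device/mesh setup are assumptions, and a real run requires an initialized distributed process group with one rank per CP shard.

```python
def run_ring_attention_sketch():
    """Sketch (not a verified recipe) of the Context Parallel APIs above."""
    # Imports kept inside the function: the surrounding module stays importable
    # on machines without torch or a distributed environment.
    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    # Assumption: set_rotate_method is exposed from this private module.
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    # Part 2 of the API: pick the KV rotation strategy.
    # "allgather" selects the all-gather based pass-KV approach,
    # "alltoall" the all-to-all based one (strings are assumptions).
    set_rotate_method("alltoall")

    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    # q/k/v in (batch, heads, seq, head_dim) layout; seq is dim 2.
    q = torch.randn(2, 8, 4096, 64, device="cuda", requires_grad=True)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Part 1 of the API: inside this context, calls to
    # F.scaled_dot_product_attention are replaced with Ring Attention, and
    # each buffer is sharded along its listed sequence dimension.
    with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out
```

Note the pairing: ``buffers`` and ``buffer_seq_dims`` are parallel sequences, so each tensor is sharded along its own stated dimension, which is why the diff's new text recommends adding every tensor computed along the sequence dimension to ``buffers``.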
