
Commit a3556fa

address review comments
1 parent ff38861 commit a3556fa

File tree

1 file changed (+10 -10 lines)


prototype_source/context_parallel.rst

Lines changed: 10 additions & 10 deletions
@@ -3,15 +3,15 @@ Introduction to Context Parallel
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
-   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/prototype_source/context_parallel.rst>`__.
+   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/prototype_source/context_parallel.rst>`__.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
-       * `1M sequence training in torchtitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__
+       * `1M sequence training in TorchTitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__


    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
@@ -29,16 +29,16 @@ It breaks the constraint on input sequence length resulting from peak memory usa
The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
Ring Attention shuffles the KV shards and calculates the partial attention scores,
repeats until all KV shards have been used on each device.
-We implemented two Ring Attention variants: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
+Two Ring Attention variants have been implemented: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
The pass-KV approach all-gathers KV shards while performing the local SDPA (Scaled Dot Product Attention) then performs the rest when the communication completes.
The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA computation and the all-to-all communication
necessary for the next SDPA.

The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
-   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
-   argument ``buffers`` and ``buffer_seq_dims`` respectively.
+   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
+   argument ``buffers`` and ``buffer_seq_dims`` respectively.
2. ``set_rotate_method()`` allows users to choose between the pass-KV approach and the all-to-all approach.

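The hunk below exercises ``context_parallel()`` on a pre-built ``device_mesh`` and a list of tensors ``cp_qkv``, neither of which is created in this diff. A minimal sketch of that setup, assuming a 1-D mesh over all ranks and replicated Q/K/V with layout (batch, nheads, seq_len, head_dim); the shapes, dtype, and the private import location of ``context_parallel_unshard`` are assumptions, not part of this change:

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Assumes the default process group is already initialized (e.g. launched via torchrun).
    world_size = dist.get_world_size()
    device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))

    # Replicated Q, K, V; dim 2 is the sequence dimension that
    # context_parallel() will shard in place across the mesh.
    qkv = [
        torch.rand(2, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    ]
    cp_qkv = [t.detach().clone() for t in qkv]
    backend = SDPBackend.FLASH_ATTENTION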
@@ -157,17 +157,17 @@ shard to input and distribute the computation across ranks:

    with sdpa_kernel(backend):
        # This `context_parallel()` performs two actions:
-       # 1. shard the tensor objects in `buffers` in-place along the dimension
+       # 1. Shard the tensor objects in `buffers` in-place along the dimension
        # specified in `buffer_seq_dims`, the tensors in `buffers` and their
        # sharding dims in `buffer_seq_dims` are organized in the same order.
-       # 2. replace the execution of `F.scaled_dot_product_attention` with a
+       # 2. Replace the execution of `F.scaled_dot_product_attention` with a
        # context-paralleled-enabled Ring Attention.
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

-   # the output `cp_out` is still sharded in the same way as QKV
+   # The output `cp_out` is still sharded in the same way as QKV
    # the `context_parallel_unshard` API allows users to easily
    # unshard to gain the full tensor.
    (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
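Not part of this diff, but a natural follow-up to the unshard call above: once ``cp_out`` covers the full sequence again, it can be checked against a plain single-context SDPA run on the replicated ``qkv`` from the setup sketch. The tolerances below are assumptions for bfloat16:

    # Reference run without Context Parallel, on the replicated copies.
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)

    # After context_parallel_unshard, cp_out spans the full sequence again,
    # so it should match the single-context result up to numerical noise.
    assert torch.allclose(cp_out, out, atol=1e-2, rtol=1e-2)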
@@ -216,6 +216,6 @@ You can choose the desired shards rotation approach in Ring Attention by using `
Conclusion
----------

-In this tutorial, have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
-design and implementation details, performance analysis, and an end-to-end training example in `torchtitan <https://github.com/pytorch/torchtitan>`__,
+In this tutorial, we have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
+design and implementation details, performance analysis, and an end-to-end training example in `TorchTitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.
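The section referenced by the last hunk header (choosing the shards rotation approach) is not shown in this diff. A minimal sketch of that choice, assuming ``set_rotate_method`` is exposed from the same experimental attention module and accepts ``"alltoall"``, with the all-gather based pass-KV approach as the default:

    from torch.distributed.tensor.experimental._attention import set_rotate_method

    # Switch shard rotation from the default all-gather (pass-KV) approach to
    # interleaved all-to-all; call this before entering context_parallel().
    set_rotate_method("alltoall")

    # Fresh replicated copies, since the earlier context sharded cp_qkv in place.
    cp_qkv = [t.detach().clone() for t in qkv]
    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)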
