* `1M sequence training in TorchTitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__

It breaks the constraint on input sequence length resulting from peak memory usage.
The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
Ring Attention shuffles the KV shards and calculates the partial attention scores,
repeating until all KV shards have been used on each device.
Two Ring Attention variants have been implemented: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
The pass-KV approach all-gathers the KV shards while performing the local SDPA (Scaled Dot Product Attention), then performs the remaining attention computation once the communication completes.
The all-to-all approach uses interleaved all-to-all collectives to ring-shuffle the KV shards, overlapping the SDPA computation with the all-to-all communication
needed for the next SDPA.
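
To make the rotation concrete, below is a minimal single-process sketch of the idea (an illustration only, not the PyTorch implementation): each simulated rank holds one query shard, computes partial attention against the KV shard it currently holds, takes the next KV shard from its ring neighbor, and finally merges the partial results with a log-sum-exp correction. The helper names ``partial_attention`` and ``merge`` are made up for this example.

.. code:: python

    import torch

    def partial_attention(q, k, v):
        # Attention of `q` against one KV shard; also return the per-row
        # log-sum-exp of the scores so partial results can be merged later.
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)
        return torch.softmax(scores, dim=-1) @ v, lse

    def merge(outs, lses):
        # Weight each partial output by exp(lse_i - logsumexp(all lses)).
        lses = torch.stack(lses)
        weights = (lses - torch.logsumexp(lses, dim=0)).exp()
        return (weights * torch.stack(outs)).sum(dim=0)

    torch.manual_seed(0)
    world_size, seq_len, head_dim = 4, 16, 8
    q, k, v = (torch.randn(seq_len, head_dim) for _ in range(3))

    # Shard Q and KV along the sequence dimension, one shard per simulated rank.
    q_shards = q.chunk(world_size)
    kv_shards = list(zip(k.chunk(world_size), v.chunk(world_size)))

    outputs = []
    for rank in range(world_size):
        outs, lses = [], []
        for step in range(world_size):
            # Each step uses the KV shard "received" from the ring neighbor.
            k_shard, v_shard = kv_shards[(rank + step) % world_size]
            out, lse = partial_attention(q_shards[rank], k_shard, v_shard)
            outs.append(out)
            lses.append(lse)
        outputs.append(merge(outs, lses))

    # The merged shards match full (non-causal) attention over the whole sequence.
    reference = partial_attention(q, k, v)[0]
    assert torch.allclose(torch.cat(outputs), reference, atol=1e-5)
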
The Context Parallel APIs consist of two parts:
1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
   the ``buffers`` and ``buffer_seq_dims`` arguments, respectively.
2. ``set_rotate_method()`` allows users to choose between the pass-KV approach and the all-to-all approach.
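
These APIs currently live under PyTorch's experimental distributed-tensor namespace; the import paths below are what recent releases use, but since the module is experimental they may change between versions:

.. code:: python

    # `context_parallel()` wraps a region of code so that SDPA runs as Ring Attention.
    from torch.distributed.tensor.experimental import context_parallel

    # `set_rotate_method()` selects how KV shards are rotated ("allgather" for the
    # pass-KV approach, "alltoall" for the all-to-all variant, as of recent releases).
    from torch.distributed.tensor.experimental._attention import set_rotate_method
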
We then shard the input and distribute the computation across ranks:

.. code:: python

    with sdpa_kernel(backend):
        # This `context_parallel()` performs two actions:
        # 1. Shard the tensor objects in `buffers` in-place along the dimension
        #    specified in `buffer_seq_dims`; the tensors in `buffers` and their
        #    sharding dims in `buffer_seq_dims` are organized in the same order.
        # 2. Replace the execution of `F.scaled_dot_product_attention` with a
        #    context-parallel Ring Attention implementation.
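        # A sketch of the assumed continuation (not the tutorial's verbatim code):
        # `cp_mesh`, `q`, `k`, and `v` come from setup elided in this excerpt, and
        # each buffer's sequence dimension is dim 2 ([batch, heads, seq, head_dim]).
        with context_parallel(cp_mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
            cp_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)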

You can choose the desired shards rotation approach in Ring Attention by using ``set_rotate_method()``.
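
For example, a single call before entering the ``context_parallel()`` region switches the rotation approach (a sketch; the accepted string values reflect the experimental API in recent releases and may change):

.. code:: python

    from torch.distributed.tensor.experimental._attention import set_rotate_method

    # Rotate KV shards with all-to-all collectives; the default, "allgather",
    # corresponds to the pass-KV approach described earlier.
    set_rotate_method("alltoall")
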
Conclusion
----------

In this tutorial, we have learned how to easily parallelize the SDPA computation along the sequence dimension with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `TorchTitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.