prototype_source/context_parallel.rst
6 additions & 5 deletions
@@ -33,15 +33,16 @@ Two Ring Attention variants have been implemented: `the all-gather based pass-KV
 local query tensor chunk. Our modified all-gather based pass-KV algorithm concurrently all-gathers KV shards and computes attention output for the local query tensor chunk
 using local key and value tensor chunks, followed by a final computation of attention output for the local query tensor and remaining KV shards. This allows some degree of
 overlap between the attention computation and the all-gather collective.
-2. The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA computation and the all-to-all communication
+2. The all-to-all approach uses interleaved all-to-all collectives to ring shuffle KV shards to overlap the SDPA (Scaled Dot Product Attention) computation and the all-to-all communication
 necessary for the next SDPA.

 The Context Parallel APIs consist of two parts:

-1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
-   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
-   argument ``buffers`` and ``buffer_seq_dims`` respectively.
-2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
+1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
+   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
+   the arguments ``buffers`` and ``buffer_seq_dims``, respectively. We recommend that users add tensors that are computed along the sequence dimension to ``buffers``
+   and shard them along this dimension.
+2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
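
For readers skimming the diff, the two APIs described in the hunk above compose roughly as in the following sketch. It is not part of the changed file: it assumes a ``torchrun`` launch with one process per GPU, and that the prototype import paths (``context_parallel`` from ``torch.distributed.tensor.experimental``, ``set_rotate_method`` from ``torch.distributed.tensor.experimental._attention``) and the ``"allgather"``/``"alltoall"`` option names still match the current prototype.

```python
# Minimal usage sketch (assumptions noted above); launch with
#   torchrun --nproc-per-node=<num_gpus> cp_example.py
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# 1-D device mesh spanning all ranks participating in Context Parallel.
world_size = dist.get_world_size()
cp_mesh = init_device_mesh("cuda", (world_size,))

# Pick the KV rotation strategy: "allgather" (default) or "alltoall".
set_rotate_method("alltoall")

# q, k, v are (batch, heads, seq_len, head_dim); seq_len should be divisible
# by world_size so each rank gets an equal sequence shard.
q, k, v = (
    torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Inside the context, q/k/v are sharded along dim 2 (the sequence dimension)
# and SDPA dispatches to Ring Attention.
with context_parallel(cp_mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

dist.destroy_process_group()
```

If SDPA does not automatically select a supported backend on your hardware, the attention call may additionally need to be wrapped in ``torch.nn.attention.sdpa_kernel`` with the flash or memory-efficient backend; the sketch omits that for brevity.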