
Commit 09d6583

wz337 and svekars authored
update device mesh to include slicing (#2958)
* update device mesh to include slicing

---------

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 4857131 commit 09d6583

File tree

1 file changed (+20, -0 lines)

recipes_source/distributed_device_mesh.rst

Lines changed: 20 additions & 0 deletions
@@ -148,6 +148,26 @@ Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable
 
 torchrun --nproc_per_node=8 hsdp.py
 
+How to use DeviceMesh for your custom parallel solutions
+--------------------------------------------------------
+When working with large-scale training, you might have a more complex custom parallel training composition. For example, you may need to slice out submeshes for different parallelism solutions.
+DeviceMesh allows users to slice a child mesh from the parent mesh and reuse the NCCL communicators already created when the parent mesh was initialized.
+
+.. code-block:: python
+
+    from torch.distributed.device_mesh import init_device_mesh
+    mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))
+
+    # Users can slice child meshes from the parent mesh.
+    hsdp_mesh = mesh_3d["replicate", "shard"]
+    tp_mesh = mesh_3d["tp"]
+
+    # Users can access the underlying process group through the `get_group` API.
+    replicate_group = hsdp_mesh["replicate"].get_group()
+    shard_group = hsdp_mesh["shard"].get_group()
+    tp_group = tp_mesh.get_group()
+
+
 Conclusion
 ----------
 In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
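As a follow-up to the added section, here is a minimal sketch of how the sliced `hsdp_mesh` could be consumed, assuming FSDP's `device_mesh` and `sharding_strategy` arguments as used elsewhere in this recipe. The `ToyModel` module is a hypothetical placeholder for illustration and is not part of this commit.

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

    # Hypothetical toy model, used only for illustration.
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(16, 16)

        def forward(self, x):
            return self.net(x)

    # Build the 3D mesh from the diff and slice out the 2D submesh for HSDP.
    mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))
    hsdp_mesh = mesh_3d["replicate", "shard"]

    # Passing the sliced 2D mesh to FSDP with HYBRID_SHARD reuses the NCCL
    # communicators created when the parent mesh was initialized.
    model = FSDP(
        ToyModel().cuda(),
        device_mesh=hsdp_mesh,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )

This sketch assumes a launch with `torchrun --nproc_per_node=8`, matching the 2 x 2 x 2 mesh; the remaining `tp` dimension would be left for a tensor-parallel solution applied separately.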
