Commit af662e5
1. add to why DM is useful, 2. add get PG, 3. add note
1 parent 572fdd9 commit af662e5

File tree

1 file changed: +13 −2 lines changed

recipes_source/distributed_device_mesh.rst

Lines changed: 13 additions & 2 deletions
@@ -30,6 +30,9 @@ Users can also easily manage the underlying process_groups/devices for multi-dim
 
 Why DeviceMesh is Useful
 ------------------------
+DeviceMesh is useful when composability is required, that is, when your parallelism solution requires communication both across hosts and within each host.
+The image above shows that we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts in a homogeneous setup.
+
 
 The following code snippet illustrates a 2D setup without :class:`DeviceMesh`. First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
 replicate group to each rank.
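To make the manual bookkeeping described above concrete, here is a hedged, pure-Python sketch of the group calculation for the 2×4 layout in this tutorial (2 replicate groups of 4 shard ranks each). The helper name ``build_2d_groups`` is ours, not part of any PyTorch API; in a real setup, every rank would then pass each rank list to ``torch.distributed.new_group`` in the same order.

```python
def build_2d_groups(replicate_size=2, shard_size=4):
    """Compute the rank lists for the shard and replicate process groups
    of a (replicate_size, shard_size) 2D layout. Sketch only: real code
    would feed each list to torch.distributed.new_group on every rank."""
    world_size = replicate_size * shard_size
    # Each shard group is a contiguous row of ranks.
    shard_groups = [
        list(range(i * shard_size, (i + 1) * shard_size))
        for i in range(replicate_size)
    ]
    # Each replicate group strides across rows with step shard_size.
    replicate_groups = [
        list(range(i, world_size, shard_size))
        for i in range(shard_size)
    ]
    return shard_groups, replicate_groups


shard_groups, replicate_groups = build_2d_groups()
# shard_groups     -> [[0, 1, 2, 3], [4, 5, 6, 7]]
# replicate_groups -> [[0, 4], [1, 5], [2, 6], [3, 7]]
# A given rank then selects its own groups:
rank = 5
my_shard_ranks = shard_groups[rank // 4]        # [4, 5, 6, 7]
my_replicate_ranks = replicate_groups[rank % 4]  # [1, 5]
```

This is exactly the error-prone index arithmetic that :class:`DeviceMesh` automates away.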
@@ -76,13 +79,21 @@ Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable
 
 .. code-block:: python
 
     torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py
 
+.. note::
+   For simplicity of demonstration, we are simulating 2D parallelism using only one node. Note that this code snippet can also be used when running on a multi-host setup.
 
-With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines.
+With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still
+access the underlying :class:`ProcessGroup` if needed.
 
 
 .. code-block:: python
 
     from torch.distributed.device_mesh import init_device_mesh
-    device_mesh = init_device_mesh("cuda", (2, 4))
+    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
+
+    # Users can access the underlying process group through the `get_group` API.
+    replicate_group = mesh_2d.get_group(mesh_dim="replicate")
+    shard_group = mesh_2d.get_group(mesh_dim="shard")
 
 Let's create a file named ``2d_setup_with_device_mesh.py``.
 Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.
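To illustrate the semantics of the ``get_group`` calls added in the diff above, here is a hypothetical, dependency-free sketch of which ranks end up in each named dimension's group for the (2, 4) mesh. The function ``groups_for_rank`` is our illustrative helper, not part of the PyTorch API; it only mirrors the rank arithmetic that DeviceMesh performs internally.

```python
def groups_for_rank(rank, replicate_size=2, shard_size=4):
    """For a (replicate_size, shard_size) mesh with dim names
    ("replicate", "shard"), return the rank lists of the process groups
    that mesh.get_group(mesh_dim=...) would resolve to for this rank.
    Illustrative sketch only, not the real DeviceMesh implementation."""
    row, col = divmod(rank, shard_size)  # this rank's mesh coordinate
    # The "shard" group spans the rank's row; "replicate" spans its column.
    shard_ranks = [row * shard_size + c for c in range(shard_size)]
    replicate_ranks = [r * shard_size + col for r in range(replicate_size)]
    return {"replicate": replicate_ranks, "shard": shard_ranks}


# Rank 5 sits at mesh coordinate (1, 1):
groups = groups_for_rank(5)
# groups["shard"]     -> [4, 5, 6, 7]
# groups["replicate"] -> [1, 5]
```

Compare this with the manual ``new_group`` bookkeeping earlier in the tutorial: ``init_device_mesh`` derives the same group memberships from the mesh shape and dimension names, so users never compute rank lists by hand.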
