
Commit fd43ef4

wz337 and svekars committed
Apply suggestions from code review
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 9a07238 commit fd43ef4

File tree

1 file changed (+83, -77 lines)


recipes_source/distributed_device_mesh.rst

Lines changed: 83 additions & 77 deletions

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__
- Python
- PyTorch 2.2

Setting up the NVIDIA Collective Communication Library (NCCL) communicators for distributed communication during distributed training can pose a significant challenge. For workloads where users need to compose different parallelisms,
users would need to manually set up and manage NCCL communicators (for example, :class:`ProcessGroup`) for each parallelism solution. This process could be complicated and susceptible to errors.
:class:`DeviceMesh` can simplify this process, making it more manageable and less prone to errors.

What is DeviceMesh
------------------
:class:`DeviceMesh` is a higher level abstraction that manages :class:`ProcessGroup`. It allows users to effortlessly
create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub process groups.
Users can also easily manage the underlying process_groups/devices for multi-dimensional parallelism via :class:`DeviceMesh`.
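
For example, once a mesh's dimensions are given names, the mesh can be sliced by name to retrieve the :class:`ProcessGroup` backing each dimension. The snippet below is a minimal sketch, assuming mesh slicing by dimension name as available in recent PyTorch releases; the ``mesh_dim_names`` values ``"replicate"`` and ``"shard"`` are illustrative names, not requirements:

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # Create a 2 x 4 mesh and name its dimensions (the names are arbitrary).
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

    # Slice out this rank's 1D sub-mesh along each dimension and access the
    # process group that backs it.
    shard_group = mesh_2d["shard"].get_group()          # group for the "shard" dimension (e.g. FSDP within a host)
    replicate_group = mesh_2d["replicate"].get_group()  # group for the "replicate" dimension (e.g. DDP across hosts)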

.. figure:: /_static/img/distributed/device_mesh.png
   :width: 100%

Why DeviceMesh is Useful
------------------------

The following code snippet illustrates a 2D setup without :class:`DeviceMesh`. First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
replicate group to each rank.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage 2-D like parallel pattern
    dist.init_process_group("nccl")

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: sh

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines.


.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    device_mesh = init_device_mesh("cuda", (2, 4))
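
The mesh created above describes the same rank layout that the manual setup builds by hand: the groups along the second mesh dimension correspond to the shard groups, and the groups along the first dimension correspond to the replicate groups. As a quick illustration, here is a minimal sketch continuing the snippet above (assuming the ``mesh`` attribute, the tensor of global ranks backing the mesh):

.. code-block:: python

    # On the 8-GPU setup above, the mesh is a logical 2 x 4 view over the
    # global ranks: tensor([[0, 1, 2, 3], [4, 5, 6, 7]]). Its rows match the
    # manually created shard groups and its columns match the replicate groups.
    print(device_mesh.mesh)
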
Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: sh

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup_with_device_mesh.py


How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that performs FSDP within a host and DDP across hosts.

Let's see an example of how DeviceMesh can assist with applying HSDP to your model. With DeviceMesh,
users do not need to manually create and manage the shard group and replicate group.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )

Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: sh

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 hsdp.py


Conclusion
----------
In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.

For more information, please see the following:

- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__
- `Composable PyTorch Distributed with PT2 <https://static.sched.com/hosted_files/pytorch2023/d1/%5BPTC%2023%5D%20Composable%20PyTorch%20Distributed%20with%20PT2.pdf>`__
