
Commit 6a90666

add device mesh recipe
1 parent e77fc43 commit 6a90666

File tree

3 files changed
+142 -0 lines changed

recipes_source/distributed_device_mesh.rst
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
Getting Started with DeviceMesh
===============================

**Author**: `Iris Zhang <https://github.com/wz337>`__, `Wanchao Liang <https://github.com/wanchaol>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_device_mesh.rst>`__.

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__

Setting up NCCL communicators for distributed communication during distributed training can be challenging. For workloads where users need to compose different parallelisms,
they would have to manually set up and manage NCCL communicators (for example, :class:`ProcessGroup`) for each parallelism solution. This process is complicated and error-prone.
:class:`DeviceMesh` can help make this process much easier.

What is DeviceMesh
------------------

:class:`DeviceMesh` is a higher level abstraction that manages :class:`ProcessGroup`. It allows users to easily
create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub process groups.
Users can also easily manage the underlying process_groups/devices for multi-dimensional parallelism via :class:`DeviceMesh`.

.. figure:: /_static/img/distributed/device_mesh.png
   :width: 100%
   :align: center
   :alt: PyTorch DeviceMesh

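For intuition: on a single host with eight GPUs, a ``(2, 4)`` mesh arranges ranks 0-7 into two rows of four, and each row or column corresponds to one process group. The sketch below is illustrative only; it assumes eight processes launched via ``torchrun`` and uses the mesh's ``mesh`` attribute, which stores the rank layout as a tensor:

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # With 8 processes, a (2, 4) mesh lays the ranks out as:
    #   [[0, 1, 2, 3],
    #    [4, 5, 6, 7]]
    # Each row is a group along dim 1, each column a group along dim 0.
    mesh_2d = init_device_mesh("cuda", (2, 4))
    print(mesh_2d.mesh)  # rank layout as a torch.Tensor (assumed attribute)
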
Why DeviceMesh is Useful
------------------------

Below is the code snippet for a 2D setup without :class:`DeviceMesh`. First, we need to manually calculate the shard group and the replicate group. Then, we need to assign the correct shard and
replicate group to each rank.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage 2-D like parallel pattern
    dist.init_process_group("nccl")

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (e.g. (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

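Note that the bookkeeping does not end with group creation: every collective issued later must be handed the correct handle for the dimension it operates on. The fragment below is purely illustrative; it reuses ``rank``, ``current_shard_group``, and ``current_replicate_group`` from the manual setup above and assumes one GPU per process:

.. code-block:: python

    # Pick this process's GPU, then reduce along each parallelism dimension,
    # passing the matching group handle to every collective call.
    device = torch.device("cuda", rank % torch.cuda.device_count())
    local_tensor = torch.ones(2, device=device)
    dist.all_reduce(local_tensor, group=current_shard_group)      # within the shard group
    dist.all_reduce(local_tensor, group=current_replicate_group)  # across replicas
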
With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    device_mesh = init_device_mesh("cuda", (2, 4))

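If you still need a raw :class:`ProcessGroup` for a given dimension, the mesh can hand it back instead of you tracking ``new_group`` handles yourself. The snippet below is a sketch and assumes a recent PyTorch release where :class:`DeviceMesh` exposes a ``get_group`` accessor:

.. code-block:: python

    # For the (2, 4) mesh above, dimension 0 spans the "columns" (e.g. ranks 0 and 4)
    # and dimension 1 the "rows" (e.g. ranks 0-3); each call returns the
    # ProcessGroup that the current rank belongs to along that dimension.
    replicate_group = device_mesh.get_group(0)
    shard_group = device_mesh.get_group(1)
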
Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup_with_device_mesh.py

How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that combines FSDP-style sharding within each group of devices with replication across groups.

Let's see an example of how DeviceMesh can assist with applying the Hybrid Sharding strategy to your model.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )

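To exercise the wrapped model, each rank can run a small training step. The optimizer choice and the random input below are illustrative assumptions rather than part of the original recipe:

.. code-block:: python

    # Illustrative training step on each rank; FSDP has already placed the
    # sharded parameters on this rank's GPU via the device mesh, so we read
    # the device from the parameters rather than hard-coding it.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

    inp = torch.rand(20, 10, device=next(model.parameters()).device)
    loss = model(inp).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
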
Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 hsdp.py

Conclusion
----------

In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.

For more information, please see the following:

- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -324,6 +324,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/DCP_tutorial.html
    :tags: Distributed-Training

+.. customcarditem::
+   :header: Getting Started with DeviceMesh
+   :card_description: Learn how to use DeviceMesh
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/distributed_device_mesh.html
+   :tags: Distributed-Training
+
 .. TorchServe

 .. customcarditem::
