Commit bcaa9f6

wz337, svekars, and wanchaol authored
[DeviceMesh] Add device mesh recipe (#2718)
* add DeviceMesh recipe --------- Co-authored-by: Svetlana Karslioglu <svekars@meta.com> Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
1 parent eec8d56 commit bcaa9f6

File tree

4 files changed: +185 -1 lines changed


distributed/home.rst

Lines changed: 19 additions & 1 deletion
@@ -13,6 +13,7 @@ PyTorch with each method having their advantages in certain use cases:

* `DistributedDataParallel (DDP) <#learn-ddp>`__
* `Fully Sharded Data Parallel (FSDP) <#learn-fsdp>`__
+* `Device Mesh <#device-mesh>`__
* `Remote Procedure Call (RPC) distributed training <#learn-rpc>`__
* `Custom Extensions <#custom-extensions>`__

@@ -51,7 +52,7 @@ Learn DDP
      :link: https://pytorch.org/tutorials/advanced/generic_join.html?utm_source=distr_landing&utm_medium=generic_join
      :link-type: url

-     This tutorial describes the Join context manager and
+     This tutorial describes the Join context manager and
      demonstrates its use with DistributedDataParallel.
      +++
      :octicon:`code;1em` Code
@@ -83,6 +84,23 @@ Learn FSDP
      +++
      :octicon:`code;1em` Code

+.. _device-mesh:
+
+Learn DeviceMesh
+----------------
+
+.. grid:: 3
+
+   .. grid-item-card:: :octicon:`file-code;1em`
+      Getting Started with DeviceMesh
+      :link: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html?highlight=devicemesh
+      :link-type: url
+
+      In this tutorial you will learn about `DeviceMesh`
+      and how it can help with distributed training.
+      +++
+      :octicon:`code;1em` Code
+
.. _learn-rpc:

Learn RPC
recipes_source/distributed_device_mesh.rst

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
Getting Started with DeviceMesh
=====================================================

**Author**: `Iris Zhang <https://github.com/wz337>`__, `Wanchao Liang <https://github.com/wanchaol>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_device_mesh.rst>`__.

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__
- Python 3.8 - 3.11
- PyTorch 2.2


Setting up distributed communicators, that is, NVIDIA Collective Communication Library (NCCL) communicators, for distributed training can pose a significant challenge. For workloads where users need to compose different parallelisms,
users would need to manually set up and manage NCCL communicators (for example, :class:`ProcessGroup`) for each parallelism solution. This process can be complicated and susceptible to errors.
:class:`DeviceMesh` can simplify this process, making it more manageable and less prone to errors.

What is DeviceMesh
------------------
:class:`DeviceMesh` is a higher-level abstraction that manages :class:`ProcessGroup`. It allows users to effortlessly
create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub-process groups.
Users can also easily manage the underlying process groups/devices for multi-dimensional parallelism via :class:`DeviceMesh`.

.. figure:: /_static/img/distributed/device_mesh.png
   :width: 100%
   :align: center
   :alt: PyTorch DeviceMesh

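As a minimal illustrative sketch (assuming a single host with 8 GPUs and a launch via ``torchrun --nproc_per_node=8``), even a 1-D mesh shows what :class:`DeviceMesh` handles for you: it sets up the process group and owns the NCCL communicator, which you can retrieve with ``get_group``.

.. code-block:: python

    # Illustrative sketch, assuming `torchrun --nproc_per_node=8` on one host.
    from torch.distributed.device_mesh import init_device_mesh

    # One flat dimension over the 8 local GPUs; DeviceMesh creates the
    # underlying ProcessGroup so we never call dist.new_group() ourselves.
    mesh_1d = init_device_mesh("cuda", (8,))
    print(mesh_1d)              # the mesh and the ranks it spans
    print(mesh_1d.get_group())  # the ProcessGroup backing this mesh dimension
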
Why DeviceMesh is Useful
------------------------
DeviceMesh is useful when working with multi-dimensional parallelism (that is, 3-D parallelism) where parallelism composability is required, for example, when your parallelism solution requires both communication across hosts and within each host.
The image above shows that we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts, in a homogeneous setup.

Without DeviceMesh, users would need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism, which could be quite complicated.
The following code snippet illustrates a hybrid sharding 2-D parallel pattern setup without :class:`DeviceMesh`.
First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
replicate group to each rank.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage 2-D like parallel pattern
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

.. note::
   For simplicity of demonstration, we are simulating 2D parallelism using only one node. Note that this code snippet can also be used when running on a multi-host setup.

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still
access the underlying :class:`ProcessGroup` if needed.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

    # Users can access the underlying process group through the `get_group` API.
    replicate_group = mesh_2d.get_group(mesh_dim="replicate")
    shard_group = mesh_2d.get_group(mesh_dim="shard")

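As a quick, illustrative sanity check (assuming the snippet above is launched with the same ``torchrun`` command), each rank can report where it sits inside the two sub-groups; ``torch.distributed.get_rank`` and ``torch.distributed.get_world_size`` both accept a process group argument.

.. code-block:: python

    import torch.distributed as dist

    # Illustrative check: print this rank's position inside the replicate
    # and shard groups created by the DeviceMesh above.
    print(
        f"global rank {dist.get_rank()}: "
        f"replicate rank {dist.get_rank(replicate_group)} of {dist.get_world_size(replicate_group)}, "
        f"shard rank {dist.get_rank(shard_group)} of {dist.get_world_size(shard_group)}"
    )
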
Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py


How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that performs FSDP within a host and DDP across hosts.

Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh,
users would not need to manually create and manage the shard group and replicate group.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )

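Once the model is wrapped, a training step looks the same as with any FSDP model. The following illustrative sketch assumes the wrapped ``model`` from the snippet above, a random dummy batch, and that the current CUDA device has already been set for each rank.

.. code-block:: python

    # Illustrative training step for the FSDP-wrapped `model` above.
    # Assumes each rank has already set its current CUDA device.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inputs = torch.rand(8, 10, device="cuda")  # dummy batch of 8 samples
    loss = model(inputs).sum()                 # forward pass through the sharded model
    loss.backward()                            # HSDP handles the gradient communication
    optimizer.step()
    optimizer.zero_grad()
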
Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 hsdp.py

Conclusion
----------
In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.

For more information, please see the following:

- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__
- `Composable PyTorch Distributed with PT2 <https://static.sched.com/hosted_files/pytorch2023/d1/%5BPTC%2023%5D%20Composable%20PyTorch%20Distributed%20with%20PT2.pdf>`__

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -313,6 +313,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu

.. Distributed Training

+.. customcarditem::
+   :header: Getting Started with DeviceMesh
+   :card_description: Learn how to use DeviceMesh
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/distributed_device_mesh.html
+   :tags: Distributed-Training
+
.. customcarditem::
   :header: Shard Optimizer States with ZeroRedundancyOptimizer
   :card_description: How to use ZeroRedundancyOptimizer to reduce memory consumption.
