Getting Started with DeviceMesh
=====================================================

**Author**: `Iris Zhang <https://github.com/wz337>`__, `Wanchao Liang <https://github.com/wanchaol>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_device_mesh.rst>`__.

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__
- Python 3.8 - 3.11
- PyTorch 2.2


Setting up distributed communicators, such as NVIDIA Collective Communication Library (NCCL) communicators, for distributed training can pose a significant challenge. For workloads that compose multiple parallelisms,
users need to manually set up and manage NCCL communicators (for example, :class:`ProcessGroup`) for each parallelism solution. This process can be complicated and susceptible to errors.
:class:`DeviceMesh` can simplify this process, making it more manageable and less prone to errors.

What is DeviceMesh
------------------
:class:`DeviceMesh` is a higher-level abstraction that manages :class:`ProcessGroup`. It allows users to effortlessly
create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub-process groups.
Users can also easily manage the underlying process groups and devices for multi-dimensional parallelism via :class:`DeviceMesh`.
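
As a quick illustration, here is a minimal sketch of the API. The single-host, 8-GPU job shape used here is an illustrative assumption, for example a script launched via ``torchrun --nproc_per_node=8``:

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # Illustrative sketch: build a one-dimensional mesh over the 8 ranks of the job.
    mesh_1d = init_device_mesh("cuda", (8,))

    # The underlying ProcessGroup stays accessible for lower-level collectives.
    group = mesh_1d.get_group(mesh_dim=0)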

.. figure:: /_static/img/distributed/device_mesh.png
   :width: 100%
   :align: center
   :alt: PyTorch DeviceMesh

Why DeviceMesh is Useful
------------------------
DeviceMesh is useful when working with multi-dimensional parallelism (for example, 3-D parallelism) where parallelism composability is required, such as when your parallelism solution needs both communication across hosts and within each host.
The image above shows that we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts, in a homogeneous setup.

Without DeviceMesh, users would need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism, which can be quite complicated.
The following code snippet illustrates a hybrid sharding 2-D parallel pattern setup without :class:`DeviceMesh`.
First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
replicate group to each rank.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage 2-D like parallel pattern
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

.. note::
   For simplicity of demonstration, we are simulating 2D parallelism using only one node. Note that this code snippet can also be used when running on a multi-host setup.

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still
access the underlying :class:`ProcessGroup` if needed.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

    # Users can access the underlying process group through the `get_group` API.
    replicate_group = mesh_2d.get_group(mesh_dim="replicate")
    shard_group = mesh_2d.get_group(mesh_dim="shard")
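
To confirm that this mesh reproduces the manually created layout, a quick sanity check like the one below can be appended to the same script. This is an optional sketch; it assumes ``torch.distributed.get_process_group_ranks``, which is available in recent PyTorch releases.

.. code-block:: python

    import torch.distributed as dist

    # Optional sanity check (illustrative): print the global ranks that belong to
    # this rank's replicate and shard groups. They should match the groups that
    # were created manually in ``2d_setup.py``.
    print(
        f"rank={dist.get_rank()} "
        f"replicate_ranks={dist.get_process_group_ranks(replicate_group)} "
        f"shard_ranks={dist.get_process_group_ranks(shard_group)}"
    )
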
Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py


How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that performs FSDP within a host and DDP across hosts.

Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh,
users would not need to manually create and manage the shard and replicate groups.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )
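
To verify that the wrapped model runs end to end, you could append a small smoke test to the same script. This is an illustrative sketch, not part of the original recipe; the input shape simply matches the first linear layer of ``ToyModel``.

.. code-block:: python

    # Illustrative smoke test: one forward/backward pass with dummy data.
    # The input is placed on the same device as the (sharded) model parameters.
    device = next(model.parameters()).device
    inp = torch.rand(20, 10, device=device)
    output = model(inp)
    output.sum().backward()
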
Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.

.. code-block:: bash

    torchrun --nproc_per_node=8 hsdp.py

Conclusion
----------
In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.

For more information, please see the following:

- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__
- `Composable PyTorch Distributed with PT2 <https://static.sched.com/hosted_files/pytorch2023/d1/%5BPTC%2023%5D%20Composable%20PyTorch%20Distributed%20with%20PT2.pdf>`__