Getting Started with DeviceMesh
=====================================================

**Author**: `Iris Zhang <https://github.com/wz337>`__, `Wanchao Liang <https://github.com/wanchaol>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_device_mesh.rst>`__.

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__

Setting up NCCL communicators for distributed communication during distributed training can be challenging. For workloads where users need to compose different parallelisms,
users would need to manually set up and manage NCCL communicators (for example, a :class:`ProcessGroup`) for each parallelism solution. This is fairly complicated and error-prone.
:class:`DeviceMesh` can help make this process much easier.
|
What is DeviceMesh
------------------

:class:`DeviceMesh` is a higher-level abstraction that manages :class:`ProcessGroup`. It allows users to easily
create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub process groups.
Users can also easily manage the underlying process_groups/devices for multi-dimensional parallelism via :class:`DeviceMesh`.
|
.. figure:: /_static/img/distributed/device_mesh.png
   :width: 100%
   :align: center
   :alt: PyTorch DeviceMesh

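Conceptually, a :class:`DeviceMesh` is an n-dimensional arrangement of the ranks in the world, and process groups are created for the slices along each mesh dimension. The snippet below is an illustrative sketch only: it assumes a single host with 8 GPUs launched via ``torchrun``, the dimension names ``"dp"``/``"tp"`` are arbitrary labels chosen for this example, and exact attribute/helper availability may vary across PyTorch versions.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # Arrange the 8 ranks into a 2 x 4 grid; DeviceMesh creates the
    # process groups along both dimensions for us.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

    # The underlying layout is a tensor of global ranks, e.g.
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7]]
    print(mesh_2d.mesh)
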
Why DeviceMesh is Useful
------------------------

Below is a code snippet for a 2D parallel setup without :class:`DeviceMesh`. First, we need to manually calculate the shard groups and replicate groups. Then, we need to assign the correct shard and
replicate group to each rank.
|
.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Understand world topology
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"Running example on {rank=} in a world with {world_size=}")

    # Create process groups to manage a 2-D parallel pattern
    dist.init_process_group("nccl")

    # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
    # and assign the correct shard group to each rank
    num_node_devices = torch.cuda.device_count()
    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    current_shard_group = (
        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
    )

    # Create replicate groups (e.g. (0, 4), (1, 5), (2, 6), (3, 7))
    # and assign the correct replicate group to each rank
    current_replicate_group = None
    shard_factor = len(shard_rank_lists[0])
    for i in range(num_node_devices // 2):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if rank in replicate_group_ranks:
            current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.
|
.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines.
|
.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    device_mesh = init_device_mesh("cuda", (2, 4))

Let's create a file named ``2d_setup_with_device_mesh.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.
|
.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup_with_device_mesh.py

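If parts of your code still need the raw communicators, for example to issue a manual collective along one dimension, the mesh can hand them back without any rank bookkeeping. Below is a minimal sketch under the same single-node, 8-GPU assumption; the dimension names ``("replicate", "shard")`` are labels picked for illustration, and exact helper availability may vary by PyTorch version.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Bind this process to its GPU (LOCAL_RANK is set by torchrun).
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Name the mesh dimensions so they can be referenced by name later.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

    # Retrieve the ProcessGroup backing each mesh dimension for this rank.
    replicate_group = mesh_2d.get_group("replicate")
    shard_group = mesh_2d.get_group("shard")

    # Use a per-dimension group directly in a collective: this all-reduce only
    # involves the 4 ranks that share this rank's shard group.
    t = torch.ones(1, device="cuda") * mesh_2d.get_rank()
    dist.all_reduce(t, group=shard_group)
    print(f"rank {mesh_2d.get_rank()}: shard-group sum = {t.item()}")
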
How to use DeviceMesh with HSDP
-------------------------------

Hybrid Sharding Data Parallel (HSDP) is a 2D strategy: parameters are sharded FSDP-style within each shard group, while the model is replicated DDP-style across shard groups.
Let's see an example of how DeviceMesh can assist with applying the Hybrid Sharding strategy to your model.

.. code-block:: python

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # HSDP: MeshShape(2, 4)
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = FSDP(
        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
    )

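To actually exercise the wrapped model inside ``hsdp.py``, a short training loop can follow the wrapping code. The loop below is only an illustrative sketch: the ``Adam`` optimizer, the loss, and the random input batch are placeholders for your own training logic.

.. code-block:: python

    import torch
    import torch.distributed as dist
    from torch.optim import Adam

    # Continues from the FSDP-wrapped ``model`` above.
    # The optimizer must be constructed after wrapping with FSDP.
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Each rank trains on its own (placeholder) batch. FSDP shards parameters
    # within each shard group and keeps replicas in sync across groups.
    device = next(model.parameters()).device
    for _ in range(3):
        inp = torch.rand(20, 10, device=device)
        loss = model(inp).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Tear down the process groups once training is done.
    dist.destroy_process_group()
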
Let's create a file named ``hsdp.py``.
Then, run the following `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command.
|
.. code-block:: bash

    torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 hsdp.py

Conclusion
----------

In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
they can be used to describe the layout of devices across the cluster.
|
For more information, please see the following:
|
- `2D parallel combining Tensor/Sequence Parallel with FSDP <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__