diff --git a/_static/img/distributed/device_mesh.png b/_static/img/distributed/device_mesh.png
new file mode 100644
index 00000000000..2ccabcc4824
Binary files /dev/null and b/_static/img/distributed/device_mesh.png differ
diff --git a/distributed/home.rst b/distributed/home.rst
index aac2a1df494..ff0dbf73e5a 100644
--- a/distributed/home.rst
+++ b/distributed/home.rst
@@ -13,6 +13,7 @@ PyTorch with each method having their advantages in certain use cases:
 
 * `DistributedDataParallel (DDP) <#learn-ddp>`__
 * `Fully Sharded Data Parallel (FSDP) <#learn-fsdp>`__
+* `Device Mesh <#device-mesh>`__
 * `Remote Procedure Call (RPC) distributed training <#learn-rpc>`__
 * `Custom Extensions <#custom-extensions>`__
 
@@ -51,7 +52,7 @@ Learn DDP
       :link: https://pytorch.org/tutorials/advanced/generic_join.html?utm_source=distr_landing&utm_medium=generic_join
       :link-type: url
 
-      This tutorial describes the Join context manager and 
+      This tutorial describes the Join context manager and
       demonstrates its use with DistributedDataParallel.
       +++
       :octicon:`code;1em` Code
@@ -83,6 +84,23 @@ Learn FSDP
       +++
       :octicon:`code;1em` Code
 
+.. _device-mesh:
+
+Learn DeviceMesh
+----------------
+
+.. grid:: 3
+
+   .. grid-item-card:: :octicon:`file-code;1em`
+      Getting Started with DeviceMesh
+      :link: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html?highlight=devicemesh
+      :link-type: url
+
+      In this tutorial you will learn about `DeviceMesh`
+      and how it can help with distributed training.
+      +++
+      :octicon:`code;1em` Code
+
 .. _learn-rpc:
 
 Learn RPC
diff --git a/recipes_source/distributed_device_mesh.rst b/recipes_source/distributed_device_mesh.rst
new file mode 100644
index 00000000000..ded1ecd4e99
--- /dev/null
+++ b/recipes_source/distributed_device_mesh.rst
@@ -0,0 +1,159 @@
+Getting Started with DeviceMesh
+=====================================================
+
+**Author**: `Iris Zhang `__, `Wanchao Liang `__
+
+.. note::
+   |edit| View and edit this tutorial on `GitHub `__.
+
+Prerequisites:
+
+- `Distributed Communication Package - torch.distributed `__
+- Python 3.8 - 3.11
+- PyTorch 2.2
+
+
+Setting up distributed communicators, such as NVIDIA Collective Communication Library (NCCL) communicators, for distributed training can pose a significant challenge. For workloads where users need to compose different parallelisms,
+users would need to manually set up and manage NCCL communicators (for example, :class:`ProcessGroup`) for each parallelism solution. This process can be complicated and error-prone.
+:class:`DeviceMesh` can simplify this process, making it more manageable and less prone to errors.
+
+What is DeviceMesh
+------------------
+:class:`DeviceMesh` is a higher-level abstraction that manages :class:`ProcessGroup`. It allows users to effortlessly
+create inter-node and intra-node process groups without worrying about how to set up ranks correctly for different sub-process groups.
+Users can also easily manage the underlying process groups and devices for multi-dimensional parallelism via :class:`DeviceMesh`.
+
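+As a minimal warm-up sketch before the multi-dimensional examples below (assuming a single host with
+8 GPUs, a ``torchrun`` launch, and the placeholder script name ``mesh_demo.py``), creating a mesh takes
+a single call, and :func:`init_device_mesh` should set up the required process group behind the scenes
+if one has not been initialized yet:
+
+.. code-block:: python
+
+    # mesh_demo.py (placeholder name); run with: torchrun --nproc_per_node=8 mesh_demo.py
+    from torch.distributed.device_mesh import init_device_mesh
+
+    # One flat mesh over all 8 ranks; no manual rank bookkeeping is needed.
+    mesh_1d = init_device_mesh("cuda", (8,))
+    print(mesh_1d)
+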
+.. figure:: /_static/img/distributed/device_mesh.png
+   :width: 100%
+   :align: center
+   :alt: PyTorch DeviceMesh
+
+Why DeviceMesh is Useful
+------------------------
+DeviceMesh is useful when working with multi-dimensional parallelism (for example, 3-D parallelism) where parallelism composability is required,
+such as when your parallelism solution requires both communication across hosts and within each host.
+The image above shows that we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts in a homogeneous setup.
+
+Without DeviceMesh, users would need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism, which can be quite complicated.
+The following code snippet illustrates a hybrid sharding 2-D parallel pattern setup without :class:`DeviceMesh`.
+First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and
+replicate group to each rank.
+
+.. code-block:: python
+
+    import os
+
+    import torch
+    import torch.distributed as dist
+
+    # Understand world topology
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    print(f"Running example on {rank=} in a world with {world_size=}")
+
+    # Create process groups to manage 2-D like parallel pattern
+    dist.init_process_group("nccl")
+    torch.cuda.set_device(rank)
+
+    # Create shard groups (for example, (0, 1, 2, 3), (4, 5, 6, 7))
+    # and assign the correct shard group to each rank
+    num_node_devices = torch.cuda.device_count()
+    shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
+    shard_groups = (
+        dist.new_group(shard_rank_lists[0]),
+        dist.new_group(shard_rank_lists[1]),
+    )
+    current_shard_group = (
+        shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
+    )
+
+    # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
+    # and assign the correct replicate group to each rank
+    current_replicate_group = None
+    shard_factor = len(shard_rank_lists[0])
+    for i in range(num_node_devices // 2):
+        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
+        replicate_group = dist.new_group(replicate_group_ranks)
+        if rank in replicate_group_ranks:
+            current_replicate_group = replicate_group
+
+To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``.
+Then, run the following `torch elastic/torchrun `__ command.
+
+.. code-block:: bash
+
+    torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py
+
+.. note::
+   For simplicity of demonstration, we are simulating 2D parallelism using only one node. Note that this code snippet can also be used when running in a multi-host setup.
+
+With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still
+access the underlying :class:`ProcessGroup` if needed.
+
+
+.. code-block:: python
+
+    from torch.distributed.device_mesh import init_device_mesh
+    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
+
+    # Users can access the underlying process group through the `get_group` API.
+    replicate_group = mesh_2d.get_group(mesh_dim="replicate")
+    shard_group = mesh_2d.get_group(mesh_dim="shard")
+
+Let's create a file named ``2d_setup_with_device_mesh.py``.
+Then, run the following `torch elastic/torchrun `__ command.
+
+.. code-block:: bash
+
+    torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py
+
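+Since ``get_group`` returns a regular :class:`ProcessGroup`, the handles obtained above can be passed
+directly to the usual ``torch.distributed`` collectives. As an illustrative sketch (the tensor and the
+``all_reduce`` call are arbitrary choices, and the snippet assumes it is appended to
+``2d_setup_with_device_mesh.py`` and launched with the same ``torchrun`` command as above), we can
+reduce a tensor only within each shard group:
+
+.. code-block:: python
+
+    import torch
+    import torch.distributed as dist
+
+    # Bind each process to its own GPU (single-host assumption, matching the
+    # manual setup earlier) so that NCCL collectives use distinct devices.
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+
+    # This all-reduce spans only the 4 ranks in this rank's shard group,
+    # not the full world of 8 ranks.
+    t = torch.ones(2, device="cuda")
+    dist.all_reduce(t, group=shard_group)
+    print(f"rank {dist.get_rank()}: {t}")
+
+The same pattern applies to ``replicate_group``; this is the kind of collective call that would
+otherwise require the manually constructed process groups shown earlier.
+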
+How to use DeviceMesh with HSDP
+-------------------------------
+
+Hybrid Sharding Data Parallel (HSDP) is a 2D strategy that performs FSDP within a host and DDP across hosts.
+
+Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh,
+users do not need to manually create and manage shard and replicate groups.
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+
+
+    class ToyModel(nn.Module):
+        def __init__(self):
+            super(ToyModel, self).__init__()
+            self.net1 = nn.Linear(10, 10)
+            self.relu = nn.ReLU()
+            self.net2 = nn.Linear(10, 5)
+
+        def forward(self, x):
+            return self.net2(self.relu(self.net1(x)))
+
+
+    # HSDP: MeshShape(2, 4)
+    mesh_2d = init_device_mesh("cuda", (2, 4))
+    model = FSDP(
+        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
+    )
+
+Let's create a file named ``hsdp.py``.
+Then, run the following `torch elastic/torchrun `__ command.
+
+.. code-block:: bash
+
+    torchrun --nproc_per_node=8 hsdp.py
+
+Conclusion
+----------
+In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how
+they can be used to describe the layout of devices across the cluster.
+
+For more information, please see the following:
+
+- `2D parallel combining Tensor/Sequence Parallel with FSDP `__
+- `Composable PyTorch Distributed with PT2 `__
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst
index 5ea68f7aa8f..411eca29d2c 100644
--- a/recipes_source/recipes_index.rst
+++ b/recipes_source/recipes_index.rst
@@ -313,6 +313,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
 
 .. Distributed Training
 
+.. customcarditem::
+   :header: Getting Started with DeviceMesh
+   :card_description: Learn how to use DeviceMesh to simplify process group setup for multi-dimensional parallelism.
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/distributed_device_mesh.html
+   :tags: Distributed-Training
+
 .. customcarditem::
    :header: Shard Optimizer States with ZeroRedundancyOptimizer
    :card_description: How to use ZeroRedundancyOptimizer to reduce memory consumption.