From 07b61f592ec1d7effafb19d6abe2b7e3e174df91 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 01/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- distributed_comm_debug_mode.rst | 166 ++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 distributed_comm_debug_mode.rst diff --git a/distributed_comm_debug_mode.rst b/distributed_comm_debug_mode.rst new file mode 100644 index 00000000000..7b36aebbb62 --- /dev/null +++ b/distributed_comm_debug_mode.rst @@ -0,0 +1,166 @@ +Using CommDebugMode +===================================================== + +**Author**: `Anshul Sinha `__ + +Prerequisites: + +- `Distributed Communication Package - torch.distributed `__ +- Python 3.8 - 3.11 +- PyTorch 2.2 + + +What is CommDebugMode and why is it useful +------------------ +As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch has proposed DistributedTensor (DTensor)which abstracts away the complexities of tensor communication in distributed training, providing a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. To address this challenge, my internship project aims to develop and enhance CommDebugMode, a Python context manager that will serve as one of the primary debugging tools for DTensors. CommDebugMode is a python context manager that enables users to view when and why collective operations are happening when using DTensors, addressing this problem. + + +Why DeviceMesh is Useful +------------------------ +DeviceMesh is useful when working with multi-dimensional parallelism (i.e. 3-D parallel) where parallelism composability is required. For example, when your parallelism solutions require both communication across hosts and within each host. +The image above shows that we can create a 2D mesh that connects the devices within each host, and connects each device with its counterpart on the other hosts in a homogenous setup. + +Without DeviceMesh, users would need to manually set up NCCL communicators, cuda devices on each process before applying any parallelism, which could be quite complicated. +The following code snippet illustrates a hybrid sharding 2-D Parallel pattern setup without :class:`DeviceMesh`. +First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and +replicate group to each rank. + +.. code-block:: python + + import os + + import torch + import torch.distributed as dist + + # Understand world topology + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + print(f"Running example on {rank=} in a world with {world_size=}") + + # Create process groups to manage 2-D like parallel pattern + dist.init_process_group("nccl") + torch.cuda.set_device(rank) + + # Create shard groups (e.g. 
(0, 1, 2, 3), (4, 5, 6, 7)) + # and assign the correct shard group to each rank + num_node_devices = torch.cuda.device_count() + shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices)) + shard_groups = ( + dist.new_group(shard_rank_lists[0]), + dist.new_group(shard_rank_lists[1]), + ) + current_shard_group = ( + shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1] + ) + + # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7)) + # and assign the correct replicate group to each rank + current_replicate_group = None + shard_factor = len(shard_rank_lists[0]) + for i in range(num_node_devices // 2): + replicate_group_ranks = list(range(i, num_node_devices, shard_factor)) + replicate_group = dist.new_group(replicate_group_ranks) + if rank in replicate_group_ranks: + current_replicate_group = replicate_group + +To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``. +Then, run the following `torch elastic/torchrun `__ command. + +.. code-block:: python + + torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py + +.. note:: + For simplicity of demonstration, we are simulating 2D parallel using only one node. Note that this code snippet can also be used when running on multi hosts setup. + +With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still +access the underlying :class:`ProcessGroup` if needed. + + +.. code-block:: python + + from torch.distributed.device_mesh import init_device_mesh + mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard")) + + # Users can access the underlying process group thru `get_group` API. + replicate_group = mesh_2d.get_group(mesh_dim="replicate") + shard_group = mesh_2d.get_group(mesh_dim="shard") + +Let's create a file named ``2d_setup_with_device_mesh.py``. +Then, run the following `torch elastic/torchrun `__ command. + +.. code-block:: python + + torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py + + +How to use DeviceMesh with HSDP +------------------------------- + +Hybrid Sharding Data Parallel(HSDP) is 2D strategy to perform FSDP within a host and DDP across hosts. + +Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh, +users would not need to manually create and manage shard group and replicate group. + +.. code-block:: python + + import torch + import torch.nn as nn + + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy + + + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + + # HSDP: MeshShape(2, 4) + mesh_2d = init_device_mesh("cuda", (2, 4)) + model = FSDP( + ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + +Let's create a file named ``hsdp.py``. +Then, run the following `torch elastic/torchrun `__ command. + +.. code-block:: python + + torchrun --nproc_per_node=8 hsdp.py + +How to use DeviceMesh for your custom parallel solutions +-------------------------------------------------------- +When working with large scale training, you might have more complex custom parallel training composition. 
For example, you may need to slice out submeshes for different parallelism solutions. +DeviceMesh allows users to slice child mesh from the parent mesh and re-use the NCCL communicators already created when the parent mesh is initialized. + +.. code-block:: python + + from torch.distributed.device_mesh import init_device_mesh + mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp")) + + # Users can slice child meshes from the parent mesh. + hsdp_mesh = mesh_3d["replicate", "shard"] + tp_mesh = mesh_3d["tp"] + + # Users can access the underlying process group thru `get_group` API. + replicate_group = hsdp_mesh["replicate"].get_group() + shard_group = hsdp_mesh["Shard"].get_group() + tp_group = tp_mesh.get_group() + + +Conclusion +---------- +In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how +they can be used to describe the layout of devices across the cluster. + +For more information, please see the following: + +- `2D parallel combining Tensor/Sequance Parallel with FSDP `__ +- `Composable PyTorch Distributed with PT2 `__ From c758b29a2788382dd53f97b413898d4af04f667d Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 02/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- distributed_comm_debug_mode.rst | 155 +++----------------------------- 1 file changed, 14 insertions(+), 141 deletions(-) diff --git a/distributed_comm_debug_mode.rst b/distributed_comm_debug_mode.rst index 7b36aebbb62..dcf7f47dd99 100644 --- a/distributed_comm_debug_mode.rst +++ b/distributed_comm_debug_mode.rst @@ -15,152 +15,25 @@ What is CommDebugMode and why is it useful As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch has proposed DistributedTensor (DTensor)which abstracts away the complexities of tensor communication in distributed training, providing a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. To address this challenge, my internship project aims to develop and enhance CommDebugMode, a Python context manager that will serve as one of the primary debugging tools for DTensors. CommDebugMode is a python context manager that enables users to view when and why collective operations are happening when using DTensors, addressing this problem. -Why DeviceMesh is Useful +How to use CommDebugMode ------------------------ -DeviceMesh is useful when working with multi-dimensional parallelism (i.e. 3-D parallel) where parallelism composability is required. For example, when your parallelism solutions require both communication across hosts and within each host. -The image above shows that we can create a 2D mesh that connects the devices within each host, and connects each device with its counterpart on the other hosts in a homogenous setup. - -Without DeviceMesh, users would need to manually set up NCCL communicators, cuda devices on each process before applying any parallelism, which could be quite complicated. 
-The following code snippet illustrates a hybrid sharding 2-D Parallel pattern setup without :class:`DeviceMesh`. -First, we need to manually calculate the shard group and replicate group. Then, we need to assign the correct shard and -replicate group to each rank. - -.. code-block:: python - - import os - - import torch - import torch.distributed as dist - - # Understand world topology - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - print(f"Running example on {rank=} in a world with {world_size=}") - - # Create process groups to manage 2-D like parallel pattern - dist.init_process_group("nccl") - torch.cuda.set_device(rank) - - # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7)) - # and assign the correct shard group to each rank - num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices)) - shard_groups = ( - dist.new_group(shard_rank_lists[0]), - dist.new_group(shard_rank_lists[1]), - ) - current_shard_group = ( - shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1] - ) - - # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7)) - # and assign the correct replicate group to each rank - current_replicate_group = None - shard_factor = len(shard_rank_lists[0]) - for i in range(num_node_devices // 2): - replicate_group_ranks = list(range(i, num_node_devices, shard_factor)) - replicate_group = dist.new_group(replicate_group_ranks) - if rank in replicate_group_ranks: - current_replicate_group = replicate_group - -To run the above code snippet, we can leverage PyTorch Elastic. Let's create a file named ``2d_setup.py``. -Then, run the following `torch elastic/torchrun `__ command. - -.. code-block:: python - - torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py - -.. note:: - For simplicity of demonstration, we are simulating 2D parallel using only one node. Note that this code snippet can also be used when running on multi hosts setup. - -With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still -access the underlying :class:`ProcessGroup` if needed. - - -.. code-block:: python - - from torch.distributed.device_mesh import init_device_mesh - mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard")) - - # Users can access the underlying process group thru `get_group` API. - replicate_group = mesh_2d.get_group(mesh_dim="replicate") - shard_group = mesh_2d.get_group(mesh_dim="shard") - -Let's create a file named ``2d_setup_with_device_mesh.py``. -Then, run the following `torch elastic/torchrun `__ command. - -.. code-block:: python - - torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py - - -How to use DeviceMesh with HSDP -------------------------------- - -Hybrid Sharding Data Parallel(HSDP) is 2D strategy to perform FSDP within a host and DDP across hosts. - -Let's see an example of how DeviceMesh can assist with applying HSDP to your model with a simple setup. With DeviceMesh, -users would not need to manually create and manage shard group and replicate group. +Using CommDebugMode and getting its output is very simple. .. 
code-block:: python - import torch - import torch.nn as nn - - from torch.distributed.device_mesh import init_device_mesh - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy - - - class ToyModel(nn.Module): - def __init__(self): - super(ToyModel, self).__init__() - self.net1 = nn.Linear(10, 10) - self.relu = nn.ReLU() - self.net2 = nn.Linear(10, 5) - - def forward(self, x): - return self.net2(self.relu(self.net1(x))) - + comm_mode = CommDebugMode() + with comm_mode: + output = model(inp) - # HSDP: MeshShape(2, 4) - mesh_2d = init_device_mesh("cuda", (2, 4)) - model = FSDP( - ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD - ) + # print the operation level collective tracing information + print(comm_mode.generate_comm_debug_tracing_table(noise_level=2)) -Let's create a file named ``hsdp.py``. -Then, run the following `torch elastic/torchrun `__ command. + # log the operation level collective tracing information to a file + comm_mode.log_comm_debug_tracing_table_to_file( + noise_level=1, file_name="transformer_operation_log.txt" + ) + # dump the operation level collective tracing information to json file, + # used in the visual browser below + comm_mode.generate_json_dump(noise_level=2) .. code-block:: python - - torchrun --nproc_per_node=8 hsdp.py - -How to use DeviceMesh for your custom parallel solutions --------------------------------------------------------- -When working with large scale training, you might have more complex custom parallel training composition. For example, you may need to slice out submeshes for different parallelism solutions. -DeviceMesh allows users to slice child mesh from the parent mesh and re-use the NCCL communicators already created when the parent mesh is initialized. - -.. code-block:: python - - from torch.distributed.device_mesh import init_device_mesh - mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp")) - - # Users can slice child meshes from the parent mesh. - hsdp_mesh = mesh_3d["replicate", "shard"] - tp_mesh = mesh_3d["tp"] - - # Users can access the underlying process group thru `get_group` API. - replicate_group = hsdp_mesh["replicate"].get_group() - shard_group = hsdp_mesh["Shard"].get_group() - tp_group = tp_mesh.get_group() - - -Conclusion ----------- -In conclusion, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, as well as how -they can be used to describe the layout of devices across the cluster. - -For more information, please see the following: - -- `2D parallel combining Tensor/Sequance Parallel with FSDP `__ -- `Composable PyTorch Distributed with PT2 `__ From 6f9b7670e5fa3ea16997b0535dc2e1a6bff219a7 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 03/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- distributed_comm_debug_mode.rst | 138 +++++++++++++++++++++++++++++--- 1 file changed, 129 insertions(+), 9 deletions(-) diff --git a/distributed_comm_debug_mode.rst b/distributed_comm_debug_mode.rst index dcf7f47dd99..3d93d3edda1 100644 --- a/distributed_comm_debug_mode.rst +++ b/distributed_comm_debug_mode.rst @@ -25,15 +25,135 @@ Using CommDebugMode and getting its output is very simple. 
with comm_mode: output = model(inp) - # print the operation level collective tracing information - print(comm_mode.generate_comm_debug_tracing_table(noise_level=2)) + # print the operation level collective tracing information + print(comm_mode.generate_comm_debug_tracing_table(noise_level=2)) - # log the operation level collective tracing information to a file - comm_mode.log_comm_debug_tracing_table_to_file( - noise_level=1, file_name="transformer_operation_log.txt" - ) + # log the operation level collective tracing information to a file + comm_mode.log_comm_debug_tracing_table_to_file( + noise_level=1, file_name="transformer_operation_log.txt" + ) - # dump the operation level collective tracing information to json file, - # used in the visual browser below - comm_mode.generate_json_dump(noise_level=2) + # dump the operation level collective tracing information to json file, + # used in the visual browser below + comm_mode.generate_json_dump(noise_level=2) .. code-block:: python + +All users have to do is wrap the code running the model in CommDebugMode and call the API that they want to use to display the data. +Documentation Title +=================== + +Introduction to the Module +-------------------------- + +Below is the interactive module tree visualization: + +.. raw:: html + + + + + + + CommDebugMode Module Tree + + + +
+ <!-- Embedded HTML/JavaScript widget: a "Drag file here" drop area that renders an uploaded CommDebugMode JSON dump as an interactive module tree. -->
+ + + + +.. raw:: html From 2e3a21a78f8d32901710442c2f1478ffa5129e66 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 04/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- distributed_comm_debug_mode.rst | 39 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/distributed_comm_debug_mode.rst b/distributed_comm_debug_mode.rst index 3d93d3edda1..df7661152ee 100644 --- a/distributed_comm_debug_mode.rst +++ b/distributed_comm_debug_mode.rst @@ -11,7 +11,7 @@ Prerequisites: What is CommDebugMode and why is it useful ------------------- +------------------------------------------ As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch has proposed DistributedTensor (DTensor)which abstracts away the complexities of tensor communication in distributed training, providing a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. To address this challenge, my internship project aims to develop and enhance CommDebugMode, a Python context manager that will serve as one of the primary debugging tools for DTensors. CommDebugMode is a python context manager that enables users to view when and why collective operations are happening when using DTensors, addressing this problem. @@ -26,7 +26,7 @@ Using CommDebugMode and getting its output is very simple. output = model(inp) # print the operation level collective tracing information - print(comm_mode.generate_comm_debug_tracing_table(noise_level=2)) + print(comm_mode.generate_comm_debug_tracing_table(noise_level=0)) # log the operation level collective tracing information to a file comm_mode.log_comm_debug_tracing_table_to_file( @@ -36,16 +36,35 @@ Using CommDebugMode and getting its output is very simple. # dump the operation level collective tracing information to json file, # used in the visual browser below comm_mode.generate_json_dump(noise_level=2) -.. code-block:: python -All users have to do is wrap the code running the model in CommDebugMode and call the API that they want to use to display the data. -Documentation Title -=================== + """ + This is what the output looks like for a MLPModule at noise level 0 + Expected Output: + Global + FORWARD PASS + *c10d_functional.all_reduce: 1 + MLPModule + FORWARD PASS + *c10d_functional.all_reduce: 1 + MLPModule.net1 + MLPModule.relu + MLPModule.net2 + FORWARD PASS + *c10d_functional.all_reduce: 1 + """ + +All users have to do is wrap the code running the model in CommDebugMode and call the API that they want to use to display the data. One important thing to note +is that the users can use a noise_level arguement to control how much information is displayed to the user. You can see what each noise_level will display to the user. + +| 0. prints module-level collective counts +| 1. prints dTensor operations not included in trivial operations, module information +| 2. prints operations not included in trivial operations +| 3. 
prints all operations -Introduction to the Module --------------------------- +In the example above, users can see in the first picture that the collective operation, all_reduce, occurs once in the forward pass of the MLPModule. The second picture provides a greater level of detail, allowing users to pinpoint that the all-reduce operation happens in the second linear layer of the MLPModule. -Below is the interactive module tree visualization: + +Below is the interactive module tree visualization that users can upload their JSON dump to: .. raw:: html @@ -155,5 +174,3 @@ Below is the interactive module tree visualization: - -.. raw:: html From 63a0105dd4a347af315a4d006bba6c43e4384574 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 05/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- .../distributed_comm_debug_mode.rst | 29 ++++++++++++++++--- recipes_source/recipes_index.rst | 8 +++++ 2 files changed, 33 insertions(+), 4 deletions(-) rename distributed_comm_debug_mode.rst => recipes_source/distributed_comm_debug_mode.rst (74%) diff --git a/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst similarity index 74% rename from distributed_comm_debug_mode.rst rename to recipes_source/distributed_comm_debug_mode.rst index df7661152ee..fe3b13ecd73 100644 --- a/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -12,7 +12,16 @@ Prerequisites: What is CommDebugMode and why is it useful ------------------------------------------ -As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch has proposed DistributedTensor (DTensor)which abstracts away the complexities of tensor communication in distributed training, providing a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. To address this challenge, my internship project aims to develop and enhance CommDebugMode, a Python context manager that will serve as one of the primary debugging tools for DTensors. CommDebugMode is a python context manager that enables users to view when and why collective operations are happening when using DTensors, addressing this problem. +As the size of models continues to increase, users are seeking to leverage various combinations +of parallel strategies to scale up distributed training. However, the lack of interoperability +between existing solutions poses a significant challenge, primarily due to the absence of a +unified abstraction that can bridge these different parallelism strategies. To address this +issue, PyTorch has proposed DistributedTensor(DTensor) which abstracts away the complexities of +tensor communication in distributed training, providing a seamless user experience. However, +this abstraction creates a lack of transparency that can make it challenging for users to +identify and resolve issues. 
To address this challenge, CommDebugMode, a Python context manager +will serve as one of the primary debugging tools for DTensors, enabling users to view when and +why collective operations are happening when using DTensors, addressing this problem. How to use CommDebugMode @@ -53,15 +62,19 @@ Using CommDebugMode and getting its output is very simple. *c10d_functional.all_reduce: 1 """ -All users have to do is wrap the code running the model in CommDebugMode and call the API that they want to use to display the data. One important thing to note -is that the users can use a noise_level arguement to control how much information is displayed to the user. You can see what each noise_level will display to the user. +All users have to do is wrap the code running the model in CommDebugMode and call the API that +they want to use to display the data. One important thing to note is that the users can use a noise_level +arguement to control how much information is displayed to the user. The information below shows what each +noise level displays | 0. prints module-level collective counts | 1. prints dTensor operations not included in trivial operations, module information | 2. prints operations not included in trivial operations | 3. prints all operations -In the example above, users can see in the first picture that the collective operation, all_reduce, occurs once in the forward pass of the MLPModule. The second picture provides a greater level of detail, allowing users to pinpoint that the all-reduce operation happens in the second linear layer of the MLPModule. +In the example above, users can see in the first picture that the collective operation, all_reduce, occurs +once in the forward pass of the MLPModule. The second picture provides a greater level of detail, allowing +users to pinpoint that the all-reduce operation happens in the second linear layer of the MLPModule. Below is the interactive module tree visualization that users can upload their JSON dump to: @@ -174,3 +187,11 @@ Below is the interactive module tree visualization that users can upload their J + +Conclusion +------------------------------------------ +In conclusion, we have learned how to use CommDebugMode in order to debug Distributed Tensors +and can use future json dumps in the embedded visual browser. + +For more detailed information about CommDebugMode, please see +https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index c9aa2947a7d..166c84177b2 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :link: ../recipes/distributed_async_checkpoint_recipe.html :tags: Distributed-Training +.. customcarditem:: + :header: Getting Started with CommDebugMode + :card_description: Learn how to use CommDebugMode for DTensors + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../recipes/distributed_comm_debug_mode.html + :tags: Distributed-Training + .. TorchServe .. 
customcarditem:: @@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu /recipes/cuda_rpc /recipes/distributed_optim_torchscript /recipes/mobile_interpreter + /recipes/distributed_comm_debug_mode From 4cc0a2d975903a0b8b6551bc5cb36446a3ea4127 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:17:34 -0700 Subject: [PATCH 06/14] [dtensor][debug] tutorial showing users how to use commdebugmode and giving access to visual browser --- .../distributed_comm_debug_mode.rst | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index fe3b13ecd73..0ecca645ded 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -3,11 +3,10 @@ Using CommDebugMode **Author**: `Anshul Sinha `__ -Prerequisites: +Prerequisites -- `Distributed Communication Package - torch.distributed `__ - Python 3.8 - 3.11 -- PyTorch 2.2 +- PyTorch 2.2 or later What is CommDebugMode and why is it useful @@ -16,17 +15,20 @@ As the size of models continues to increase, users are seeking to leverage vario of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this -issue, PyTorch has proposed DistributedTensor(DTensor) which abstracts away the complexities of -tensor communication in distributed training, providing a seamless user experience. However, -this abstraction creates a lack of transparency that can make it challenging for users to -identify and resolve issues. To address this challenge, CommDebugMode, a Python context manager -will serve as one of the primary debugging tools for DTensors, enabling users to view when and -why collective operations are happening when using DTensors, addressing this problem. +issue, PyTorch has proposed `DistributedTensor(DTensor) +`_ +which abstracts away the complexities of tensor communication in distributed training, +providing a seamless user experience. However, this abstraction creates a lack of transparency +that can make it challenging for users to identify and resolve issues. To address this challenge, +``CommDebugMode``, a Python context manager will serve as one of the primary debugging tools for +DTensors, enabling users to view when and why collective operations are happening when using DTensors, +effectively addressing this issue. How to use CommDebugMode ------------------------ -Using CommDebugMode and getting its output is very simple. + +Here is how you can use ``CommDebugMode``: .. code-block:: python @@ -46,6 +48,8 @@ Using CommDebugMode and getting its output is very simple. # used in the visual browser below comm_mode.generate_json_dump(noise_level=2) +.. code-block:: python + """ This is what the output looks like for a MLPModule at noise level 0 Expected Output: @@ -62,19 +66,18 @@ Using CommDebugMode and getting its output is very simple. *c10d_functional.all_reduce: 1 """ -All users have to do is wrap the code running the model in CommDebugMode and call the API that -they want to use to display the data. One important thing to note is that the users can use a noise_level -arguement to control how much information is displayed to the user. 
The information below shows what each -noise level displays +To use ``CommDebugMode``, you must wrap the code running the model in ``CommDebugMode`` and call the API that +you want to use to display the data. You can also use a ``noise_level`` argument to control the verbosity +level of displayed information. Here is what each noise level displays: -| 0. prints module-level collective counts -| 1. prints dTensor operations not included in trivial operations, module information -| 2. prints operations not included in trivial operations -| 3. prints all operations +| 0. Prints module-level collective counts +| 1. Prints dTensor operations not included in trivial operations, module information +| 2. Prints operations not included in trivial operations +| 3. Prints all operations -In the example above, users can see in the first picture that the collective operation, all_reduce, occurs -once in the forward pass of the MLPModule. The second picture provides a greater level of detail, allowing -users to pinpoint that the all-reduce operation happens in the second linear layer of the MLPModule. +In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass +of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` pinpoint that the all-reduce operation happens +in the second linear layer of the ``MLPModule``. Below is the interactive module tree visualization that users can upload their JSON dump to: @@ -190,8 +193,10 @@ Below is the interactive module tree visualization that users can upload their J Conclusion ------------------------------------------ -In conclusion, we have learned how to use CommDebugMode in order to debug Distributed Tensors -and can use future json dumps in the embedded visual browser. -For more detailed information about CommDebugMode, please see -https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py +In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors. You can use your +own JSON outputs in the embedded visual browser. + +For more detailed information about ``CommDebugMode``, see +`comm_mode_features_example.py +`_ From 1cccebd1e9cb9ddad17646596a65b6d244801b12 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 8 Aug 2024 11:51:27 -0700 Subject: [PATCH 07/14] Apply suggestions from code review --- recipes_source/distributed_comm_debug_mode.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index 0ecca645ded..bffbe372d51 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -76,11 +76,11 @@ level of displayed information. Here is what each noise level displays: | 3. Prints all operations In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass -of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` pinpoint that the all-reduce operation happens +of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` to pinpoint that the all-reduce operation happens in the second linear layer of the ``MLPModule``. -Below is the interactive module tree visualization that users can upload their JSON dump to: +Below is the interactive module tree visualization that you can use to upload your own JSON dump: .. 
raw:: html From 05069ca38db78292b2bd6e9713235d2b51f488b9 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:42:07 -0700 Subject: [PATCH 08/14] mend --- recipes_source/distributed_comm_debug_mode.rst | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index bffbe372d51..28821b553dc 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -18,8 +18,10 @@ unified abstraction that can bridge these different parallelism strategies. To a issue, PyTorch has proposed `DistributedTensor(DTensor) `_ which abstracts away the complexities of tensor communication in distributed training, -providing a seamless user experience. However, this abstraction creates a lack of transparency -that can make it challenging for users to identify and resolve issues. To address this challenge, +providing a seamless user experience. However, when dealing with existing parallelism solutions +and developing parallelism solutions using the unified abstraction like DTensor, the lack of +transparency about what and when the collective communications happens under the hood could +make it challenging for advanced users to identify and resolve issues. To address this challenge, ``CommDebugMode``, a Python context manager will serve as one of the primary debugging tools for DTensors, enabling users to view when and why collective operations are happening when using DTensors, effectively addressing this issue. @@ -31,7 +33,7 @@ How to use CommDebugMode Here is how you can use ``CommDebugMode``: .. code-block:: python - + # The model used in this example is a MLPModule that applies Tensor Parallel comm_mode = CommDebugMode() with comm_mode: output = model(inp) @@ -71,8 +73,8 @@ you want to use to display the data. You can also use a ``noise_level`` argument level of displayed information. Here is what each noise level displays: | 0. Prints module-level collective counts -| 1. Prints dTensor operations not included in trivial operations, module information -| 2. Prints operations not included in trivial operations +| 1. Prints DTensor operations (not including trivial operations), module sharding information +| 2. Prints tensor operations (not including trivial operations) | 3. Prints all operations In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass @@ -194,7 +196,8 @@ Below is the interactive module tree visualization that you can use to upload yo Conclusion ------------------------------------------ -In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors. You can use your +In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and +parallelism solutions that uses communication collectives with PyTorch. You can use your own JSON outputs in the embedded visual browser. 
For more detailed information about ``CommDebugMode``, see From 4bc065723ca23db9bd954a495eeab48877bc38cd Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:42:07 -0700 Subject: [PATCH 09/14] mend --- recipes_source/distributed_comm_debug_mode.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index 28821b553dc..ed165cb0eda 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -33,6 +33,7 @@ How to use CommDebugMode Here is how you can use ``CommDebugMode``: .. code-block:: python + # The model used in this example is a MLPModule that applies Tensor Parallel comm_mode = CommDebugMode() with comm_mode: From beaab5c6ba6adff59681b837c510750d74b527c3 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:42:07 -0700 Subject: [PATCH 10/14] mend --- recipes_source/distributed_comm_debug_mode.rst | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index ed165cb0eda..bffbe372d51 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -18,10 +18,8 @@ unified abstraction that can bridge these different parallelism strategies. To a issue, PyTorch has proposed `DistributedTensor(DTensor) `_ which abstracts away the complexities of tensor communication in distributed training, -providing a seamless user experience. However, when dealing with existing parallelism solutions -and developing parallelism solutions using the unified abstraction like DTensor, the lack of -transparency about what and when the collective communications happens under the hood could -make it challenging for advanced users to identify and resolve issues. To address this challenge, +providing a seamless user experience. However, this abstraction creates a lack of transparency +that can make it challenging for users to identify and resolve issues. To address this challenge, ``CommDebugMode``, a Python context manager will serve as one of the primary debugging tools for DTensors, enabling users to view when and why collective operations are happening when using DTensors, effectively addressing this issue. @@ -34,7 +32,6 @@ Here is how you can use ``CommDebugMode``: .. code-block:: python - # The model used in this example is a MLPModule that applies Tensor Parallel comm_mode = CommDebugMode() with comm_mode: output = model(inp) @@ -74,8 +71,8 @@ you want to use to display the data. You can also use a ``noise_level`` argument level of displayed information. Here is what each noise level displays: | 0. Prints module-level collective counts -| 1. Prints DTensor operations (not including trivial operations), module sharding information -| 2. Prints tensor operations (not including trivial operations) +| 1. Prints dTensor operations not included in trivial operations, module information +| 2. Prints operations not included in trivial operations | 3. 
Prints all operations In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass @@ -197,8 +194,7 @@ Below is the interactive module tree visualization that you can use to upload yo Conclusion ------------------------------------------ -In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and -parallelism solutions that uses communication collectives with PyTorch. You can use your +In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors. You can use your own JSON outputs in the embedded visual browser. For more detailed information about ``CommDebugMode``, see From 4c6b4b151ae1888996736ee745bbb885573f0ea2 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:42:07 -0700 Subject: [PATCH 11/14] mend --- recipes_source/distributed_comm_debug_mode.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index bffbe372d51..f36fd0b009e 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -18,11 +18,13 @@ unified abstraction that can bridge these different parallelism strategies. To a issue, PyTorch has proposed `DistributedTensor(DTensor) `_ which abstracts away the complexities of tensor communication in distributed training, -providing a seamless user experience. However, this abstraction creates a lack of transparency -that can make it challenging for users to identify and resolve issues. To address this challenge, -``CommDebugMode``, a Python context manager will serve as one of the primary debugging tools for -DTensors, enabling users to view when and why collective operations are happening when using DTensors, -effectively addressing this issue. +providing a seamless user experience. However, when dealing with existing parallelism solutions and +developing parallelism solutions using the unified abstraction like DTensor, the lack of transparency +about what and when the collective communications happens under the hood could make it challenging +for advanced users to identify and resolve issues. To address this challenge, ``CommDebugMode``, a +Python context manager will serve as one of the primary debugging tools for DTensors, enabling +users to view when and why collective operations are happening when using DTensors, effectively +addressing this issue. How to use CommDebugMode From ba97b6431934af3cc325f7bf69c05da79aac11e9 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 8 Aug 2024 11:51:27 -0700 Subject: [PATCH 12/14] Apply suggestions from code review --- recipes_source/distributed_comm_debug_mode.rst | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index f36fd0b009e..7ebeb0b95ab 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -34,6 +34,7 @@ Here is how you can use ``CommDebugMode``: .. code-block:: python + # The model used in this example is a MLPModule applying Tensor Parallel comm_mode = CommDebugMode() with comm_mode: output = model(inp) @@ -73,8 +74,8 @@ you want to use to display the data. You can also use a ``noise_level`` argument level of displayed information. Here is what each noise level displays: | 0. 
Prints module-level collective counts -| 1. Prints dTensor operations not included in trivial operations, module information -| 2. Prints operations not included in trivial operations +| 1. Prints DTensor operations (not including trivial operations), module sharding information +| 2. Prints tensor operations (not including trivial operations) | 3. Prints all operations In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass @@ -196,8 +197,9 @@ Below is the interactive module tree visualization that you can use to upload yo Conclusion ------------------------------------------ -In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors. You can use your -own JSON outputs in the embedded visual browser. +In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and +parallelism solutions that uses communication collectives with PyTorch. You can use your own +JSON outputs in the embedded visual browser. For more detailed information about ``CommDebugMode``, see `comm_mode_features_example.py From c62b4f1d5d562388c11964de66656c4ef33261a6 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 8 Aug 2024 15:41:18 -0700 Subject: [PATCH 13/14] Update recipes_source/distributed_comm_debug_mode.rst --- recipes_source/distributed_comm_debug_mode.rst | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index 7ebeb0b95ab..431b433aca7 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -3,10 +3,14 @@ Using CommDebugMode **Author**: `Anshul Sinha `__ + +In this tutorial, we will explore how to use CommDebugMode with PyTorch's DistributedTensor (DTensor) for debugging by tracking collective operations in distributed training environments. + Prerequisites +--------------------- -- Python 3.8 - 3.11 -- PyTorch 2.2 or later +* Python 3.8 - 3.11 +* PyTorch 2.2 or later What is CommDebugMode and why is it useful From 86fbaabeb3bc231de44c30893deb2addcb8991a2 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 8 Aug 2024 15:41:41 -0700 Subject: [PATCH 14/14] Update recipes_source/distributed_comm_debug_mode.rst --- recipes_source/distributed_comm_debug_mode.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index 431b433aca7..f7e3265802f 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -4,7 +4,8 @@ Using CommDebugMode **Author**: `Anshul Sinha `__ -In this tutorial, we will explore how to use CommDebugMode with PyTorch's DistributedTensor (DTensor) for debugging by tracking collective operations in distributed training environments. +In this tutorial, we will explore how to use CommDebugMode with PyTorch's +DistributedTensor (DTensor) for debugging by tracking collective operations in distributed training environments. Prerequisites ---------------------
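
Taken together, the patches above leave the recipe's ``CommDebugMode`` snippet intentionally abbreviated: it does not show the model definition, the imports, or the distributed setup. For readers who want a runnable starting point, below is a minimal, self-contained sketch of exercising ``CommDebugMode`` end to end with a Tensor Parallel ``MLPModule`` on a single host. It is illustrative rather than part of the recipe itself: the ``MLPModule`` definition, the column-wise/row-wise parallelization plan, the file name ``comm_debug_example.py``, and the ``CommDebugMode`` import path (which has moved between the private ``torch.distributed._tensor`` namespace and the public ``torch.distributed.tensor`` namespace across PyTorch releases) are assumptions to adapt to your own environment.

.. code-block:: python

    # Minimal sketch, not part of the recipe: MLPModule, the parallel plan, and
    # the script name are illustrative. Launch with:
    #   torchrun --nproc_per_node=4 comm_debug_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    try:
        # Public import path in newer PyTorch releases.
        from torch.distributed.tensor.debug import CommDebugMode
    except ImportError:
        # Older releases keep CommDebugMode under the private _tensor namespace.
        from torch.distributed._tensor.debug import CommDebugMode


    class MLPModule(nn.Module):
        """Toy two-layer MLP that triggers one all-reduce under Tensor Parallel."""

        def __init__(self, dim: int = 16) -> None:
            super().__init__()
            self.net1 = nn.Linear(dim, 4 * dim)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(4 * dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net2(self.relu(self.net1(x)))


    def main() -> None:
        # torchrun sets these environment variables for every worker process.
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        torch.cuda.set_device(local_rank)
        torch.manual_seed(0)  # keep weights and inputs identical across ranks

        # 1-D mesh over all ranks; init_device_mesh also sets up the process group.
        mesh = init_device_mesh("cuda", (world_size,))

        # Shard net1 column-wise and net2 row-wise, so the forward pass needs
        # exactly one all-reduce to reassemble the output of net2.
        model = parallelize_module(
            MLPModule().cuda(),
            mesh,
            {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
        )
        inp = torch.rand(8, 16, device="cuda")

        comm_mode = CommDebugMode()
        with comm_mode:
            output = model(inp)

        if dist.get_rank() == 0:
            # Module-level collective counts only (noise_level=0).
            print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

        # JSON dump that can be loaded into the visual browser embedded in the recipe.
        comm_mode.generate_json_dump(noise_level=2)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

When launched with ``torchrun --nproc_per_node=4 comm_debug_example.py``, the rank-0 table should show a single ``all_reduce`` attributed to the forward pass of ``MLPModule.net2``, matching the expected noise-level-0 output discussed in the recipe, and the generated JSON file can be dragged into the interactive module tree visualizer.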