From 44c6e8d69db47f56d495901fee6084a174e257d5 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Fri, 16 Aug 2024 13:36:59 -0700 Subject: [PATCH 1/3] [dtensor][debug] CommDebugMode recipe --- .../distributed_comm_debug_mode.rst | 211 ++++++++++++++++++ recipes_source/recipes_index.rst | 8 + 2 files changed, 219 insertions(+) create mode 100644 recipes_source/distributed_comm_debug_mode.rst diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst new file mode 100644 index 00000000000..f7e3265802f --- /dev/null +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -0,0 +1,211 @@ +Using CommDebugMode +===================================================== + +**Author**: `Anshul Sinha `__ + + +In this tutorial, we will explore how to use CommDebugMode with PyTorch's +DistributedTensor (DTensor) for debugging by tracking collective operations in distributed training environments. + +Prerequisites +--------------------- + +* Python 3.8 - 3.11 +* PyTorch 2.2 or later + + +What is CommDebugMode and why is it useful +------------------------------------------ +As the size of models continues to increase, users are seeking to leverage various combinations +of parallel strategies to scale up distributed training. However, the lack of interoperability +between existing solutions poses a significant challenge, primarily due to the absence of a +unified abstraction that can bridge these different parallelism strategies. To address this +issue, PyTorch has proposed `DistributedTensor(DTensor) +`_ +which abstracts away the complexities of tensor communication in distributed training, +providing a seamless user experience. However, when dealing with existing parallelism solutions and +developing parallelism solutions using the unified abstraction like DTensor, the lack of transparency +about what and when the collective communications happens under the hood could make it challenging +for advanced users to identify and resolve issues. To address this challenge, ``CommDebugMode``, a +Python context manager will serve as one of the primary debugging tools for DTensors, enabling +users to view when and why collective operations are happening when using DTensors, effectively +addressing this issue. + + +How to use CommDebugMode +------------------------ + +Here is how you can use ``CommDebugMode``: + +.. code-block:: python + + # The model used in this example is a MLPModule applying Tensor Parallel + comm_mode = CommDebugMode() + with comm_mode: + output = model(inp) + + # print the operation level collective tracing information + print(comm_mode.generate_comm_debug_tracing_table(noise_level=0)) + + # log the operation level collective tracing information to a file + comm_mode.log_comm_debug_tracing_table_to_file( + noise_level=1, file_name="transformer_operation_log.txt" + ) + + # dump the operation level collective tracing information to json file, + # used in the visual browser below + comm_mode.generate_json_dump(noise_level=2) + +.. code-block:: python + + """ + This is what the output looks like for a MLPModule at noise level 0 + Expected Output: + Global + FORWARD PASS + *c10d_functional.all_reduce: 1 + MLPModule + FORWARD PASS + *c10d_functional.all_reduce: 1 + MLPModule.net1 + MLPModule.relu + MLPModule.net2 + FORWARD PASS + *c10d_functional.all_reduce: 1 + """ + +To use ``CommDebugMode``, you must wrap the code running the model in ``CommDebugMode`` and call the API that +you want to use to display the data. 
+You can also use a ``noise_level`` argument to control the verbosity
+level of displayed information. Here is what each noise level displays:
+
+| 0. Prints module-level collective counts
+| 1. Prints DTensor operations (not including trivial operations), module sharding information
+| 2. Prints tensor operations (not including trivial operations)
+| 3. Prints all operations
+
+In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass
+of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` to pinpoint that the all-reduce operation happens
+in the second linear layer of the ``MLPModule``.
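+
+The snippets above are excerpts rather than a runnable script. To make the workflow concrete, here is a
+minimal, self-contained sketch that builds the Tensor Parallel ``MLPModule`` and runs it under
+``CommDebugMode``. Treat it as a sketch under assumptions: the ``CommDebugMode`` import path is the one
+used at the time of writing and may move between PyTorch releases, the example uses a CPU (gloo) device
+mesh so it runs without GPUs (switch to ``"cuda"`` if available), and the script is expected to be
+launched with ``torchrun``.
+
+.. code-block:: python
+
+    # Example launch: torchrun --nproc_per_node=2 commdebug_example.py (file name is illustrative)
+    import os
+
+    import torch
+    import torch.distributed as dist
+    import torch.nn as nn
+    from torch.distributed._tensor.debug import CommDebugMode  # path may differ across releases
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+        parallelize_module,
+    )
+
+
+    class MLPModule(nn.Module):
+        def __init__(self, hidden: int = 16) -> None:
+            super().__init__()
+            self.net1 = nn.Linear(hidden, hidden)
+            self.relu = nn.ReLU()
+            self.net2 = nn.Linear(hidden, hidden)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.net2(self.relu(self.net1(x)))
+
+
+    torch.manual_seed(0)  # same weights and input on every rank
+
+    # 1-D mesh over all ranks; init_device_mesh also sets up the default process
+    # group when the script is launched with torchrun
+    world_size = int(os.environ["WORLD_SIZE"])
+    mesh = init_device_mesh("cpu", (world_size,))
+
+    # Shard net1 column-wise and net2 row-wise; reducing net2's partial outputs is
+    # what produces the all_reduce shown in the trace above
+    model = parallelize_module(
+        MLPModule(),
+        mesh,
+        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
+    )
+
+    inp = torch.rand(8, 16)
+
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = model(inp)
+
+    if dist.get_rank() == 0:
+        print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))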
+
+
+Below is the interactive module tree visualization that you can use to upload your own JSON dump:
+
+.. raw:: html
+
+    <!-- Embedded "CommDebugMode Module Tree" visualizer (full HTML/CSS/JavaScript omitted here):
+         it renders the module tree from the uploaded JSON dump and provides a "Drag file here"
+         drop area for the file. -->
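+
+Besides the rendered table and the JSON dump, the collected counts can also be read programmatically,
+which is convenient when unit testing parallelism code. The accessor names below, ``get_total_counts``
+and ``get_comm_counts``, are an assumption based on how PyTorch's own DTensor tests use ``CommDebugMode``
+at the time of writing; verify them against your PyTorch version.
+
+.. code-block:: python
+
+    # `model` and `inp` are the Tensor Parallel MLP and input from the sketch above
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = model(inp)
+
+    # Total number of collective calls recorded under the context manager (assumed accessor)
+    print(comm_mode.get_total_counts())
+
+    # Per-collective breakdown, keyed by the functional collective op (assumed accessor)
+    for op, count in comm_mode.get_comm_counts().items():
+        print(op, count)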
+ + + + +Conclusion +------------------------------------------ + +In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and +parallelism solutions that uses communication collectives with PyTorch. You can use your own +JSON outputs in the embedded visual browser. + +For more detailed information about ``CommDebugMode``, see +`comm_mode_features_example.py +`_ diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index 8959ea98a38..d94d7d5c22e 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :link: ../recipes/distributed_async_checkpoint_recipe.html :tags: Distributed-Training +.. customcarditem:: + :header: Getting Started with CommDebugMode + :card_description: Learn how to use CommDebugMode for DTensors + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../recipes/distributed_comm_debug_mode.html + :tags: Distributed-Training + .. TorchServe .. customcarditem:: @@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu /recipes/cuda_rpc /recipes/distributed_optim_torchscript /recipes/mobile_interpreter + /recipes/distributed_comm_debug_mode From cb37fe40c663565a9e46817a3a3fa12a0f95b342 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 19 Aug 2024 08:24:38 -0700 Subject: [PATCH 2/3] Apply suggestions from code review --- recipes_source/distributed_comm_debug_mode.rst | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index f7e3265802f..4d715d34fd0 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -1,10 +1,10 @@ -Using CommDebugMode +Getting Started with ``CommDebugMode`` ===================================================== **Author**: `Anshul Sinha `__ -In this tutorial, we will explore how to use CommDebugMode with PyTorch's +In this tutorial, we will explore how to use ``CommDebugMode`` with PyTorch's DistributedTensor (DTensor) for debugging by tracking collective operations in distributed training environments. Prerequisites @@ -14,7 +14,7 @@ Prerequisites * PyTorch 2.2 or later -What is CommDebugMode and why is it useful +What is ``CommDebugMode`` and why is it useful ------------------------------------------ As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability @@ -32,7 +32,7 @@ users to view when and why collective operations are happening when using DTenso addressing this issue. -How to use CommDebugMode +Using ``CommDebugMode`` ------------------------ Here is how you can use ``CommDebugMode``: @@ -56,10 +56,10 @@ Here is how you can use ``CommDebugMode``: # used in the visual browser below comm_mode.generate_json_dump(noise_level=2) +This is what the output looks like for a MLPModule at noise level 0: + .. 
code-block:: python - """ - This is what the output looks like for a MLPModule at noise level 0 Expected Output: Global FORWARD PASS @@ -72,7 +72,6 @@ Here is how you can use ``CommDebugMode``: MLPModule.net2 FORWARD PASS *c10d_functional.all_reduce: 1 - """ To use ``CommDebugMode``, you must wrap the code running the model in ``CommDebugMode`` and call the API that you want to use to display the data. You can also use a ``noise_level`` argument to control the verbosity From f674c2f65401ba013c6a863c0f39193843dee172 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 19 Aug 2024 08:25:14 -0700 Subject: [PATCH 3/3] Update recipes_source/distributed_comm_debug_mode.rst --- recipes_source/distributed_comm_debug_mode.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst index 4d715d34fd0..db79cdc8992 100644 --- a/recipes_source/distributed_comm_debug_mode.rst +++ b/recipes_source/distributed_comm_debug_mode.rst @@ -15,7 +15,7 @@ Prerequisites What is ``CommDebugMode`` and why is it useful ------------------------------------------- +---------------------------------------------------- As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a