
Commit ab4fd25

Merge branch 'main' into add-programmable-search
2 parents 221eb01 + 6a7b26c commit ab4fd25

File tree

3 files changed: +219 -3 lines changed


intermediate_source/TP_tutorial.rst

Lines changed: 1 addition & 3 deletions

@@ -83,8 +83,6 @@ To see how to utilize DeviceMesh to set up multi-dimensional parallelisms, pleas

 .. code-block:: python

-    # run this via torchrun: torchrun --standalone --nproc_per_node=8 ./tp_tutorial.py
-
     from torch.distributed.device_mesh import init_device_mesh

     tp_mesh = init_device_mesh("cuda", (8,))

@@ -360,4 +358,4 @@ Conclusion

 This tutorial demonstrates how to train a large Transformer-like model across hundreds to thousands of GPUs using Tensor Parallel in combination with Fully Sharded Data Parallel.
 It explains how to apply Tensor Parallel to different parts of the model, with **no code changes** to the model itself. Tensor Parallel is an efficient model parallelism technique for large-scale training.

-To see the complete end to end code example explained in this tutorial, please refer to the `Tensor Parallel examples <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__ in the pytorch/examples repository.
+To see the complete end-to-end code example explained in this tutorial, please refer to the `Tensor Parallel examples <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__ in the pytorch/examples repository.

recipes_source/distributed_comm_debug_mode.rst

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@

Getting Started with ``CommDebugMode``
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__

In this tutorial, we will explore how to use ``CommDebugMode`` with PyTorch's
DistributedTensor (DTensor) to debug distributed training by tracking the collective operations it performs.

Prerequisites
---------------------

* Python 3.8 - 3.11
* PyTorch 2.2 or later

What is ``CommDebugMode`` and why is it useful
----------------------------------------------------

As the size of models continues to increase, users are seeking to leverage various combinations
of parallel strategies to scale up distributed training. However, the lack of interoperability
between existing solutions poses a significant challenge, primarily due to the absence of a
unified abstraction that can bridge these different parallelism strategies. To address this
issue, PyTorch has proposed `DistributedTensor (DTensor)
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_,
which abstracts away the complexities of tensor communication in distributed training
and provides a seamless user experience. However, whether you are working with existing parallelism
solutions or developing new ones on top of a unified abstraction like DTensor, a lack of transparency
about which collective communications happen under the hood, and when, can make it challenging
for advanced users to identify and resolve issues. ``CommDebugMode``, a Python context manager,
addresses this challenge and serves as one of the primary debugging tools for DTensors, enabling
users to see when and why collective operations happen when using DTensors.

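To make the abstraction concrete, a minimal DTensor sketch, assuming an 8-GPU mesh and a process group
launched via ``torchrun``, might look like this:

.. code-block:: python

    # Hedged sketch: shard a plain tensor across a 1-D device mesh. The DTensor
    # that comes back decides on its own when collectives are required, which is
    # exactly the traffic CommDebugMode makes visible.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._tensor import Shard, distribute_tensor

    mesh = init_device_mesh("cuda", (8,))
    dtensor = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0)])
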
Using ``CommDebugMode``
------------------------

Here is how you can use ``CommDebugMode``:

.. code-block:: python

    # Assumed import path for the PyTorch version this recipe targets;
    # CommDebugMode lives in the DTensor debug package.
    from torch.distributed._tensor.debug import CommDebugMode

    # The model used in this example is an MLPModule with Tensor Parallel applied
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # print the operation-level collective tracing information
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

    # log the operation-level collective tracing information to a file
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1, file_name="transformer_operation_log.txt"
    )

    # dump the operation-level collective tracing information to a JSON file,
    # used in the visual browser below
    comm_mode.generate_json_dump(noise_level=2)

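The snippet above assumes that ``model`` and ``inp`` already exist. A minimal sketch of such a setup,
assuming an 8-GPU mesh launched via ``torchrun`` and an illustrative two-layer ``MLPModule`` parallelized
with Tensor Parallel (the class, mesh size, and shapes below are illustrative, not part of the recipe), might be:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class MLPModule(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.net1 = nn.Linear(dim, dim * 4)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(dim * 4, dim)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))

    # Requires a distributed launch, for example:
    #   torchrun --standalone --nproc_per_node=8 this_script.py
    tp_mesh = init_device_mesh("cuda", (8,))
    model = parallelize_module(
        MLPModule().to("cuda"),
        tp_mesh,
        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
    )
    inp = torch.rand(4, 16, device="cuda")
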
This is what the output looks like for an MLPModule at noise level 0:

.. code-block:: python

    Expected Output:
        Global
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule
              FORWARD PASS
                *c10d_functional.all_reduce: 1
              MLPModule.net1
              MLPModule.relu
              MLPModule.net2
                FORWARD PASS
                  *c10d_functional.all_reduce: 1

To use ``CommDebugMode``, wrap the code that runs the model in the ``CommDebugMode`` context manager and then
call the API you want to use to display the data. You can also pass a ``noise_level`` argument to control the
verbosity of the displayed information. Here is what each noise level displays:

| 0. Prints module-level collective counts
| 1. Prints DTensor operations (not including trivial operations) and module sharding information
| 2. Prints tensor operations (not including trivial operations)
| 3. Prints all operations

In the example above, you can see that the collective operation ``all_reduce`` occurs once in the forward pass
of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` to pinpoint that the all-reduce operation happens
in the second linear layer of the ``MLPModule``.

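Besides printing or logging the table, the traced counts can also be read programmatically. The sketch below
assumes the ``get_total_counts`` and ``get_comm_counts`` accessors are available on ``CommDebugMode`` in your
PyTorch version:

.. code-block:: python

    import torch  # for torch.ops access

    # Hedged sketch: inspect the traced collectives directly.
    total = comm_mode.get_total_counts()   # total number of collectives traced
    counts = comm_mode.get_comm_counts()   # mapping: collective op -> count
    print(total, counts[torch.ops.c10d_functional.all_reduce])
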
Below is the interactive module tree visualization that you can use to upload your own JSON dump:

.. raw:: html

    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>CommDebugMode Module Tree</title>
        <style>
            ul, #tree-container {
                list-style-type: none;
                margin: 0;
                padding: 0;
            }
            .caret {
                cursor: pointer;
                user-select: none;
            }
            .caret::before {
                content: "\25B6";
                color: black;
                display: inline-block;
                margin-right: 6px;
            }
            .caret-down::before {
                transform: rotate(90deg);
            }
            .tree {
                padding-left: 20px;
            }
            .tree ul {
                padding-left: 20px;
            }
            .nested {
                display: none;
            }
            .active {
                display: block;
            }
            .forward-pass,
            .backward-pass {
                margin-left: 40px;
            }
            .forward-pass table {
                margin-left: 40px;
                width: auto;
            }
            .forward-pass table td, .forward-pass table th {
                padding: 8px;
            }
            .forward-pass ul {
                display: none;
            }
            table {
                font-family: arial, sans-serif;
                border-collapse: collapse;
                width: 100%;
            }
            td, th {
                border: 1px solid #dddddd;
                text-align: left;
                padding: 8px;
            }
            tr:nth-child(even) {
                background-color: #dddddd;
            }
            #drop-area {
                position: relative;
                width: 25%;
                height: 100px;
                border: 2px dashed #ccc;
                border-radius: 5px;
                padding: 0px;
                text-align: center;
            }
            .drag-drop-block {
                display: inline-block;
                width: 200px;
                height: 50px;
                background-color: #f7f7f7;
                border: 1px solid #ccc;
                border-radius: 5px;
                padding: 10px;
                font-size: 14px;
                color: #666;
                cursor: pointer;
            }
            #file-input {
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                opacity: 0;
            }
        </style>
    </head>
    <body>
        <div id="drop-area">
            <div class="drag-drop-block">
                <span>Drag file here</span>
            </div>
            <input type="file" id="file-input" accept=".json">
        </div>
        <div id="tree-container"></div>
        <script src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js"></script>
    </body>
    </html>

Conclusion
------------------------------------------

In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and
parallelism solutions that use communication collectives with PyTorch. You can use your own
JSON outputs in the embedded visual browser.

For more detailed information about ``CommDebugMode``, see
`comm_mode_features_example.py
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_

recipes_source/recipes_index.rst

Lines changed: 8 additions & 0 deletions

@@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/distributed_async_checkpoint_recipe.html
    :tags: Distributed-Training

+.. customcarditem::
+   :header: Getting Started with CommDebugMode
+   :card_description: Learn how to use CommDebugMode for DTensors
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/distributed_comm_debug_mode.html
+   :tags: Distributed-Training
+
 .. TorchServe

 .. customcarditem::

@@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/cuda_rpc
    /recipes/distributed_optim_torchscript
    /recipes/mobile_interpreter
+   /recipes/distributed_comm_debug_mode
