Commit 6a7b26c

XilunWu and svekars authored

[dtensor][debug] CommDebugMode recipe (#3001)

Add [dtensor][debug] CommDebugMode recipe

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>

1 parent a811de6 commit 6a7b26c

2 files changed: +218 -0

New recipe file: 210 additions, 0 deletions
Getting Started with ``CommDebugMode``
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__


In this tutorial, we will explore how to use ``CommDebugMode`` with PyTorch's
DistributedTensor (DTensor) for debugging by tracking collective operations in
distributed training environments.

Prerequisites
---------------------

* Python 3.8 - 3.11
* PyTorch 2.2 or later


What is ``CommDebugMode`` and why is it useful
----------------------------------------------------

As models grow in size, users seek to combine multiple parallelism strategies to scale up
distributed training. However, the lack of interoperability between existing solutions poses
a significant challenge, largely because there is no unified abstraction that can bridge these
different parallelism strategies. To address this issue, PyTorch has proposed `DistributedTensor (DTensor)
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_,
which abstracts away the complexities of tensor communication in distributed training and
provides a seamless user experience. Yet whether you are working with existing parallelism
solutions or building new ones on top of a unified abstraction such as DTensor, the lack of
transparency into what collective communication happens under the hood, and when, can make it
hard for advanced users to identify and resolve issues. ``CommDebugMode``, a Python context
manager, addresses this challenge: it serves as one of the primary debugging tools for DTensor,
letting users see when and why collective operations happen while DTensors are in use.

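To make the abstraction concrete, here is a minimal sketch of constructing a DTensor sharded
across a device mesh. The shapes, the two-GPU mesh, and the public ``torch.distributed.tensor``
import path (``torch.distributed._tensor`` in older releases) are assumptions for illustration,
not part of the recipe:

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    # assumes the script was launched with torchrun on 2 GPUs
    mesh = init_device_mesh("cuda", (2,))

    # shard a 1024 x 1024 tensor row-wise across the mesh; each rank holds a
    # 512 x 1024 local shard, and DTensor issues the necessary collectives
    # whenever an operation needs data that lives on the other rank
    big_tensor = torch.randn(1024, 1024)
    dtensor = distribute_tensor(big_tensor, mesh, placements=[Shard(dim=0)])
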
Using ``CommDebugMode``
------------------------

Here is how you can use ``CommDebugMode``:

.. code-block:: python

    # CommDebugMode is importable from torch.distributed.tensor.debug in recent
    # PyTorch releases (torch.distributed._tensor.debug in older ones)
    from torch.distributed.tensor.debug import CommDebugMode

    # The model used in this example is an MLPModule with Tensor Parallel applied
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # print the operation-level collective tracing information
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

    # log the operation-level collective tracing information to a file
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1, file_name="transformer_operation_log.txt"
    )

    # dump the operation-level collective tracing information to a JSON file,
    # which can be loaded into the visual browser below
    comm_mode.generate_json_dump(noise_level=2)

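The snippet above assumes that ``model`` and ``inp`` already exist. A minimal sketch of the kind
of tensor-parallel ``MLPModule`` setup it refers to might look like the following; the layer
sizes, the two-GPU mesh, and the ``torchrun`` launch are assumptions for illustration:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class MLPModule(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.net1 = nn.Linear(dim, 4 * dim)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(4 * dim, dim)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))

    # assumes the script was launched with torchrun on 2 GPUs
    mesh = init_device_mesh("cuda", (2,))

    # column-parallel first linear, row-parallel second linear: the forward
    # pass then needs a single all-reduce after net2, which is what
    # CommDebugMode reports in the table below
    model = parallelize_module(
        MLPModule(16).cuda(),
        mesh,
        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
    )
    inp = torch.rand(8, 16, device="cuda")
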
This is what the output looks like for an MLPModule at noise level 0:

.. code-block:: python

    Expected Output:
      Global
        FORWARD PASS
          *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
          MLPModule.net1
          MLPModule.relu
          MLPModule.net2
            FORWARD PASS
              *c10d_functional.all_reduce: 1

To use ``CommDebugMode``, wrap the code that runs the model in the ``CommDebugMode`` context manager
and then call the API you want to use to display the data. You can also pass a ``noise_level``
argument to control the verbosity of the displayed information. Here is what each noise level displays:

| 0. Prints module-level collective counts
| 1. Prints DTensor operations (not including trivial operations) and module sharding information
| 2. Prints tensor operations (not including trivial operations)
| 3. Prints all operations

In the example above, you can see that the collective operation ``all_reduce`` occurs once in the
forward pass of the ``MLPModule``. Furthermore, ``CommDebugMode`` pinpoints that the all-reduce
happens in the second linear layer of the ``MLPModule``.

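The counts that back this table can also be queried programmatically, which is convenient in unit
tests. Here is a small sketch, assuming the ``get_total_counts`` and ``get_comm_counts`` helpers
that ``CommDebugMode`` exposes in recent PyTorch releases:

.. code-block:: python

    import torch
    from torch.distributed.tensor.debug import CommDebugMode

    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # total number of collectives recorded while the context manager was active
    print(comm_mode.get_total_counts())

    # per-collective breakdown, e.g. how many all-reduces were issued
    counts = comm_mode.get_comm_counts()
    print(counts[torch.ops.c10d_functional.all_reduce])
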
Below is the interactive module tree visualization that you can use to upload your own JSON dump:

.. raw:: html

    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>CommDebugMode Module Tree</title>
        <style>
            ul, #tree-container {
                list-style-type: none;
                margin: 0;
                padding: 0;
            }
            .caret {
                cursor: pointer;
                user-select: none;
            }
            .caret::before {
                content: "\25B6";
                color: black;
                display: inline-block;
                margin-right: 6px;
            }
            .caret-down::before {
                transform: rotate(90deg);
            }
            .tree {
                padding-left: 20px;
            }
            .tree ul {
                padding-left: 20px;
            }
            .nested {
                display: none;
            }
            .active {
                display: block;
            }
            .forward-pass,
            .backward-pass {
                margin-left: 40px;
            }
            .forward-pass table {
                margin-left: 40px;
                width: auto;
            }
            .forward-pass table td, .forward-pass table th {
                padding: 8px;
            }
            .forward-pass ul {
                display: none;
            }
            table {
                font-family: arial, sans-serif;
                border-collapse: collapse;
                width: 100%;
            }
            td, th {
                border: 1px solid #dddddd;
                text-align: left;
                padding: 8px;
            }
            tr:nth-child(even) {
                background-color: #dddddd;
            }
            #drop-area {
                position: relative;
                width: 25%;
                height: 100px;
                border: 2px dashed #ccc;
                border-radius: 5px;
                padding: 0px;
                text-align: center;
            }
            .drag-drop-block {
                display: inline-block;
                width: 200px;
                height: 50px;
                background-color: #f7f7f7;
                border: 1px solid #ccc;
                border-radius: 5px;
                padding: 10px;
                font-size: 14px;
                color: #666;
                cursor: pointer;
            }
            #file-input {
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                opacity: 0;
            }
        </style>
    </head>
    <body>
        <div id="drop-area">
            <div class="drag-drop-block">
                <span>Drag file here</span>
            </div>
            <input type="file" id="file-input" accept=".json">
        </div>
        <div id="tree-container"></div>
        <script src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js"></script>
    </body>
    </html>

Conclusion
------------------------------------------

In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors and
parallelism solutions that use communication collectives with PyTorch. You can also load your own
JSON outputs into the embedded visual browser.

For more detailed information about ``CommDebugMode``, see
`comm_mode_features_example.py
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_

recipes_source/recipes_index.rst

Lines changed: 8 additions & 0 deletions

@@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
    :link: ../recipes/distributed_async_checkpoint_recipe.html
    :tags: Distributed-Training

+.. customcarditem::
+   :header: Getting Started with CommDebugMode
+   :card_description: Learn how to use CommDebugMode for DTensors
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/distributed_comm_debug_mode.html
+   :tags: Distributed-Training
+
 .. TorchServe

 .. customcarditem::
@@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
    /recipes/cuda_rpc
    /recipes/distributed_optim_torchscript
    /recipes/mobile_interpreter
+   /recipes/distributed_comm_debug_mode