Using CommDebugMode
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__

In this tutorial, we will explore how to use ``CommDebugMode`` with PyTorch's
DistributedTensor (DTensor) to debug distributed training by tracking the collective
operations that happen under the hood.

Prerequisites
---------------------

* Python 3.8 - 3.11
* PyTorch 2.2 or later

What is CommDebugMode and why is it useful
------------------------------------------

As models continue to grow in size, users are seeking to leverage various combinations
of parallel strategies to scale up distributed training. However, the lack of interoperability
between existing solutions poses a significant challenge, primarily due to the absence of a
unified abstraction that can bridge these different parallelism strategies. To address this
issue, PyTorch has proposed `DistributedTensor (DTensor)
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_,
which abstracts away the complexities of tensor communication in distributed training
and provides a seamless user experience. However, when working with existing parallelism
solutions, or when developing new ones on top of a unified abstraction like DTensor, the lack
of transparency about what collective communications happen under the hood, and when, can make
it challenging for advanced users to identify and resolve issues. ``CommDebugMode``, a Python
context manager, addresses this challenge by serving as one of the primary debugging tools for
DTensors: it lets users see when and why collective operations happen when using DTensors.

How to use CommDebugMode
------------------------

Here is how you can use ``CommDebugMode``:

.. code-block:: python

    # The model used in this example is an MLPModule applying Tensor Parallel
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # print the operation-level collective tracing information
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

    # log the operation-level collective tracing information to a file
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1, file_name="transformer_operation_log.txt"
    )

    # dump the operation-level collective tracing information to a JSON file,
    # used in the visual browser below
    comm_mode.generate_json_dump(noise_level=2)

.. code-block:: python

    """
    This is what the output looks like for an MLPModule at noise level 0
    Expected Output:
        Global
          FORWARD PASS
            *c10d_functional.all_reduce: 1
          MLPModule
            FORWARD PASS
              *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1
    """

To use ``CommDebugMode``, you must wrap the code running the model in ``CommDebugMode`` and call the API that
you want to use to display the data. You can also use the ``noise_level`` argument to control the verbosity
of the displayed information. Here is what each noise level displays:

| 0. Prints module-level collective counts
| 1. Prints DTensor operations (not including trivial operations) and module sharding information
| 2. Prints tensor operations (not including trivial operations)
| 3. Prints all operations

In the example above, you can see that the collective operation ``all_reduce`` occurs once in the forward pass
of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` to pinpoint that the all-reduce operation
happens in the second linear layer of the ``MLPModule``.

Below is the interactive module tree visualization that you can use to upload your own JSON dump:

.. raw:: html

   <!DOCTYPE html>
   <html lang="en">
   <head>
      <meta charset="UTF-8">
      <meta name="viewport" content="width=device-width, initial-scale=1.0">
      <title>CommDebugMode Module Tree</title>
      <style>
         ul, #tree-container {
            list-style-type: none;
            margin: 0;
            padding: 0;
         }
         .caret {
            cursor: pointer;
            user-select: none;
         }
         .caret::before {
            content: "\25B6";
            color: black;
            display: inline-block;
            margin-right: 6px;
         }
         .caret-down::before {
            transform: rotate(90deg);
         }
         .tree {
            padding-left: 20px;
         }
         .tree ul {
            padding-left: 20px;
         }
         .nested {
            display: none;
         }
         .active {
            display: block;
         }
         .forward-pass,
         .backward-pass {
            margin-left: 40px;
         }
         .forward-pass table {
            margin-left: 40px;
            width: auto;
         }
         .forward-pass table td, .forward-pass table th {
            padding: 8px;
         }
         .forward-pass ul {
            display: none;
         }
         table {
            font-family: arial, sans-serif;
            border-collapse: collapse;
            width: 100%;
         }
         td, th {
            border: 1px solid #dddddd;
            text-align: left;
            padding: 8px;
         }
         tr:nth-child(even) {
            background-color: #dddddd;
         }
         #drop-area {
            position: relative;
            width: 25%;
            height: 100px;
            border: 2px dashed #ccc;
            border-radius: 5px;
            padding: 0px;
            text-align: center;
         }
         .drag-drop-block {
            display: inline-block;
            width: 200px;
            height: 50px;
            background-color: #f7f7f7;
            border: 1px solid #ccc;
            border-radius: 5px;
            padding: 10px;
            font-size: 14px;
            color: #666;
            cursor: pointer;
         }
         #file-input {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            opacity: 0;
         }
      </style>
   </head>
   <body>
      <div id="drop-area">
         <div class="drag-drop-block">
            <span>Drag file here</span>
         </div>
         <input type="file" id="file-input" accept=".json">
      </div>
      <div id="tree-container"></div>
      <script src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js"></script>
   </body>
   </html>

Conclusion
------------------------------------------

In this recipe, we have learned how to use ``CommDebugMode`` to debug DTensors and
parallelism solutions that use communication collectives with PyTorch, and how to upload
your own JSON outputs to the embedded visual browser.

For more detailed information about ``CommDebugMode``, see
`comm_mode_features_example.py
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_.