File changed: distributed_comm_debug_mode.rst (+28 −11)
@@ -11,7 +11,7 @@ Prerequisites:

 What is CommDebugMode and why is it useful
--------------------
+------------------------------------------

 As the size of models continues to increase, users are seeking to leverage various combinations of parallel strategies to scale up distributed training. However, the lack of interoperability between existing solutions poses a significant challenge, primarily due to the absence of a unified abstraction that can bridge these different parallelism strategies. To address this issue, PyTorch has proposed DistributedTensor (DTensor), which abstracts away the complexities of tensor communication in distributed training, providing a seamless user experience. However, this abstraction creates a lack of transparency that can make it challenging for users to identify and resolve issues. To address this challenge, CommDebugMode, a Python context manager, was developed as one of the primary debugging tools for DTensor: it lets users see when and why collective operations happen while using DTensors.
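To make the context-manager idea concrete, here is a minimal sketch (not part of the tutorial's own code) that counts the collectives triggered by redistributing a sharded DTensor. It assumes a recent PyTorch where ``CommDebugMode`` is importable from ``torch.distributed.tensor.debug`` (older releases expose it under ``torch.distributed._tensor.debug``) and that the script is launched with ``torchrun`` so a default process group can be created.

.. code-block:: python

    # Minimal sketch, not the tutorial's exact script: wrap a DTensor
    # computation in CommDebugMode and read back the collective counts.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, Replicate, distribute_tensor
    # Assumed import path; older releases: torch.distributed._tensor.debug
    from torch.distributed.tensor.debug import CommDebugMode

    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # Shard a tensor over the mesh's single dimension.
    x = distribute_tensor(torch.ones(8, 8), mesh, placements=[Shard(0)])

    comm_mode = CommDebugMode()
    with comm_mode:
        # Redistributing Shard -> Replicate requires a collective
        # (an all-gather), which CommDebugMode records.
        y = x.redistribute(mesh, placements=[Replicate()])

    print(comm_mode.get_comm_counts())            # per-collective counts
    print("total collectives:", comm_mode.get_total_counts())

    dist.destroy_process_group()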
@@ -26,7 +26,7 @@ Using CommDebugMode and getting its output is very simple.

     output = model(inp)

-    # print the operation level collective tracing information
+    # log the operation level collective tracing information to a file
     comm_mode.log_comm_debug_tracing_table_to_file(
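The call in the hunk above is cut off at the opening parenthesis. A hedged sketch of what the complete call can look like follows; the keyword names and the file name are assumptions to be checked against the ``CommDebugMode`` API in your PyTorch version, not the tutorial's exact values.

.. code-block:: python

    # Hedged sketch: one plausible complete form of the truncated call.
    # noise_level and file_name are assumed keyword names; the file
    # name itself is hypothetical.
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1,
        file_name="mlp_operation_log.txt",
    )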
@@ -36,16 +36,35 @@ Using CommDebugMode and getting its output is very simple.

     # dump the operation level collective tracing information to a JSON file,
     # used in the visual browser below
     comm_mode.generate_json_dump(noise_level=2)

-.. code-block:: python
-
-All users have to do is wrap the code running the model in CommDebugMode and call the API that they want to use to display the data.
-Documentation Title
-===================
+    """
+    This is what the output looks like for a MLPModule at noise level 0
+    Expected Output:
+      Global
+        FORWARD PASS
+          *c10d_functional.all_reduce: 1
+        MLPModule
+          FORWARD PASS
+            *c10d_functional.all_reduce: 1
+          MLPModule.net1
+          MLPModule.relu
+          MLPModule.net2
+            FORWARD PASS
+              *c10d_functional.all_reduce: 1
+    """
+
+All users have to do is wrap the code running the model in CommDebugMode and call the API they want to use to display the data. One important thing to note
+is that users can pass a noise_level argument to control how much information is displayed. Each noise level shows the following:
+
+| 0. prints module-level collective counts
+| 1. prints DTensor operations not included in trivial operations, module information
+| 2. prints operations not included in trivial operations
+| 3. prints all operations

-Introduction to the Module
---------------------------
+In the example output above, users can see that the collective operation all_reduce occurs once in the forward pass of the MLPModule, and the module-level breakdown lets them pinpoint that the all-reduce happens in the second linear layer, MLPModule.net2.

-Below is the interactive module tree visualization:
+
+Below is the interactive module tree visualization that users can upload their JSON dump to:

 .. raw:: html
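To tie the pieces of the hunk above together, here is a hedged recap of the reporting APIs it discusses, showing how the noise_level argument is passed when printing the tracing table and when dumping the JSON consumed by the visualization. The method names follow the snippets above; ``generate_comm_debug_tracing_table`` and the ``file_name`` keyword are assumptions about the current API, and ``model``/``inp`` stand for the module and input defined earlier in the tutorial.

.. code-block:: python

    # Sketch under the same assumptions as the earlier examples.
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)   # model/inp as defined earlier in the tutorial

    # Print the collective tracing table at a chosen verbosity
    # (0 = module-level counts ... 3 = all operations).
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))

    # Dump the same information as JSON for the interactive module
    # tree visualization below (the file name is hypothetical).
    comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=2)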
@@ -155,5 +174,3 @@ Below is the interactive module tree visualization: