
Commit 3549f56

Authored by Svetlana Karslioglu
Merge branch 'master' into tmm
2 parents 1c11444 + 0863302 commit 3549f56

File tree: 1 file changed, +32 -32 lines


intermediate_source/process_group_cpp_extension_tutorial.rst

Lines changed: 32 additions & 32 deletions
@@ -1,7 +1,7 @@
 Customize Process Group Backends Using Cpp Extensions
 =====================================================
 
-**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__
+**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__, `Min Si <https://minsii.github.io/>`__
 
 .. note::
    |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/master/intermediate_source/process_group_cpp_extension_tutorial.rst>`__.
@@ -62,7 +62,7 @@ Step 1: Implement a Subclass of ``ProcessGroup``
 
 This first step is to implement a ``ProcessGroup`` subclass that overrides
 target collective communication APIs and runs the custom communication algorithm.
-The extension also needs to implement a ``ProcessGroup::Work`` subclass, which
+The extension also needs to implement a ``Work`` subclass, which
 serves as a future of communication results and allows asynchronous execution in
 application code. If the extension uses third-party libraries, it can
 include the headers and call into the library APIs from the ``ProcessGroupDummy``
@@ -75,49 +75,49 @@ repository for the full implementation.
   // file name: dummy.hpp
   #include <torch/python.h>
 
-  #include <c10d/ProcessGroup.hpp>
-  #include <c10d/Store.hpp>
-  #include <c10d/Types.hpp>
-  #include <c10d/Utils.hpp>
+  #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+  #include <torch/csrc/distributed/c10d/Work.hpp>
+  #include <torch/csrc/distributed/c10d/Store.hpp>
+  #include <torch/csrc/distributed/c10d/Types.hpp>
+  #include <torch/csrc/distributed/c10d/Utils.hpp>
 
   #include <pybind11/chrono.h>
 
   namespace c10d {
 
   class ProcessGroupDummy : public ProcessGroup {
    public:
-
-    class WorkDummy : public ProcessGroup::Work {
-     public:
-      WorkDummy(
-          OpType opType,
-          c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
-          : ProcessGroup::Work(
-                -1, // rank, only used by recvAnySource, irrelevant in this demo
-                opType),
-            future_(std::move(future)) {}
-      // There are several additional helper functions that need to be
-      // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
-      // for the full implementation.
-
-     private:
-      c10::intrusive_ptr<c10::ivalue::Future> future_;
-    };
-
     ProcessGroupDummy(int rank, int size);
 
-    c10::intrusive_ptr<ProcessGroup::Work> allgather(
+    c10::intrusive_ptr<Work> allgather(
         std::vector<std::vector<at::Tensor>>& outputTensors,
         std::vector<at::Tensor>& inputTensors,
         const AllgatherOptions& opts = AllgatherOptions()) override;
 
-    c10::intrusive_ptr<ProcessGroup::Work> allreduce(
+    c10::intrusive_ptr<Work> allreduce(
         std::vector<at::Tensor>& tensors,
         const AllreduceOptions& opts = AllreduceOptions()) override;
 
     // The collective communication APIs without a custom implementation
     // will error out if invoked by application code.
   };
+
+  class WorkDummy : public Work {
+   public:
+    WorkDummy(
+        OpType opType,
+        c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
+        : Work(
+              -1, // rank, only used by recvAnySource, irrelevant in this demo
+              opType),
+          future_(std::move(future)) {}
+    // There are several additional helper functions that need to be
+    // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
+    // for the full implementation.
+
+   private:
+    c10::intrusive_ptr<c10::ivalue::Future> future_;
+  };
   } // namespace c10d
 
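The ``WorkDummy`` comment above points at helper overrides that are not shown in this diff. The following is a minimal sketch of what they might look like, modeled on the linked dummy_collectives example; exact signatures can differ across PyTorch versions.

  // Sketch only: the extra WorkDummy overrides referenced by the comment above,
  // based on https://github.com/mrshenli/dummy_collectives (illustrative, not part of this commit).
  bool isCompleted() override { return true; }       // dummy work completes immediately
  bool isSuccess() const override { return true; }   // always reports success
  bool wait(std::chrono::milliseconds /* timeout */) override { return true; }
  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { return future_; }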
@@ -130,7 +130,7 @@ repository for the full implementation.
 
   // This is a dummy allgather that sets all output tensors to zero
   // Modify the implementation to conduct real communication asynchronously
-  c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allgather(
+  c10::intrusive_ptr<Work> ProcessGroupDummy::allgather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& /* unused */) {
@@ -148,7 +148,7 @@ repository for the full implementation.
 
   // This is a dummy allreduce that sets all output tensors to zero
   // Modify the implementation to conduct real communication asynchronously
-  c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allreduce(
+  c10::intrusive_ptr<Work> ProcessGroupDummy::allreduce(
       std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts) {
     for (auto& tensor : tensors) {
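The hunk ends mid-function. For context, a hedged sketch of the full dummy ``allreduce`` body, modeled on the dummy_collectives example (details may differ from the repository): it zeroes every input tensor and returns an already-completed future wrapped in ``WorkDummy``.

  // Sketch only (not part of this diff): the complete dummy allreduce,
  // assuming the declarations from dummy.hpp above are in scope.
  c10::intrusive_ptr<Work> ProcessGroupDummy::allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& /* unused */) {
    for (auto& tensor : tensors) {
      tensor.zero_();  // "communication" result: all zeros
    }
    // Wrap the (already available) result in a completed future so callers
    // can wait() or chain on it as with a real asynchronous backend.
    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
  }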
@@ -278,11 +278,11 @@ as if it is an builtin backend.
 
     x = torch.ones(6)
     dist.all_reduce(x)
-    y = x.cuda()
-    dist.all_reduce(y)
-
     print(f"cpu allreduce: {x}")
-    print(f"cuda allreduce: {y}")
+    if torch.cuda.is_available():
+        y = x.cuda()
+        dist.all_reduce(y)
+        print(f"cuda allreduce: {y}")
 
     try:
         dist.broadcast(x, 0)
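The Python snippet above assumes the ``dummy`` backend has already been exposed and registered (the tutorial's Step 2, which this commit does not touch). A rough sketch of that entry point, with names taken from the tutorial and the dummy_collectives repository; treat the exact signatures as illustrative.

  // file name: dummy.cpp (sketch of the tutorial's Step 2, unchanged by this commit)
  // createProcessGroupDummy is declared as a static method in dummy.hpp; Python
  // passes it to torch.distributed.Backend.register_backend("dummy", ...).
  c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
      const c10::intrusive_ptr<::c10d::Store>& /* unused */,
      int rank,
      int size,
      const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<ProcessGroupDummy>(rank, size);
  }

  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
  }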
