Update PG backend extension tutorial for 1.13 release #2099

Merged 1 commit on Oct 26, 2022
64 changes: 32 additions & 32 deletions intermediate_source/process_group_cpp_extension_tutorial.rst
@@ -1,7 +1,7 @@
Customize Process Group Backends Using Cpp Extensions
=====================================================

-**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__
+**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__, `Min Si <https://minsii.github.io/>`__

.. note::
|edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/master/intermediate_source/process_group_cpp_extension_tutorial.rst>`__.
@@ -62,7 +62,7 @@ Step 1: Implement a Subclass of ``ProcessGroup``

The first step is to implement a ``ProcessGroup`` subclass that overrides
target collective communication APIs and runs the custom communication algorithm.
-The extension also needs to implement a ``ProcessGroup::Work`` subclass, which
+The extension also needs to implement a ``Work`` subclass, which
serves as a future of communication results and allows asynchronous execution in
application code. If the extension uses third-party libraries, it can
include the headers and call into the library APIs from the ``ProcessGroupDummy``
@@ -75,49 +75,49 @@ repository for the full implementation.
// file name: dummy.hpp
#include <torch/python.h>

-#include <c10d/ProcessGroup.hpp>
-#include <c10d/Store.hpp>
-#include <c10d/Types.hpp>
-#include <c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/Types.hpp>
+#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class ProcessGroupDummy : public ProcessGroup {
public:

-  class WorkDummy : public ProcessGroup::Work {
-   public:
-    WorkDummy(
-        OpType opType,
-        c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
-        : ProcessGroup::Work(
-              -1, // rank, only used by recvAnySource, irrelevant in this demo
-              opType),
-          future_(std::move(future)) {}
-    // There are several additional helper functions that need to be
-    // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
-    // for the full implementation.
-
-   private:
-    c10::intrusive_ptr<c10::ivalue::Future> future_;
-  };

ProcessGroupDummy(int rank, int size);

-c10::intrusive_ptr<ProcessGroup::Work> allgather(
+c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;

-c10::intrusive_ptr<ProcessGroup::Work> allreduce(
+c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;

// The collective communication APIs without a custom implementation
// will error out if invoked by application code.
};

+class WorkDummy : public Work {
+ public:
+  WorkDummy(
+      OpType opType,
+      c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
+      : Work(
+            -1, // rank, only used by recvAnySource, irrelevant in this demo
+            opType),
+        future_(std::move(future)) {}
+  // There are several additional helper functions that need to be
+  // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
+  // for the full implementation.
+
+ private:
+  c10::intrusive_ptr<c10::ivalue::Future> future_;
+};
} // namespace c10d
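The comment inside ``WorkDummy`` defers several helper functions to the dummy_collectives repository. As a rough sketch only, assuming the stock ``c10d::Work`` virtual interface, those overrides could look like the following for a backend whose collectives complete synchronously:

// Sketch (assumption): possible bodies for the helper overrides that the
// WorkDummy comment refers to; see the dummy_collectives repo for the real code.
#include <chrono>

#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d {

class WorkDummy : public Work {
 public:
  WorkDummy(OpType opType, c10::intrusive_ptr<c10::ivalue::Future> future)
      : Work(-1 /* rank, unused here */, opType), future_(std::move(future)) {}

  // The dummy collective writes its outputs before returning, so the Work
  // object is born completed and can never fail.
  bool isCompleted() override { return true; }
  bool isSuccess() const override { return true; }

  // Nothing to block on; report success immediately.
  bool wait(std::chrono::milliseconds /* timeout */) override { return true; }

  // Hand back the future holding the already-populated output tensors.
  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { return future_; }

 private:
  c10::intrusive_ptr<c10::ivalue::Future> future_;
};

} // namespace c10d

A real backend would instead track the state of its outstanding communication and complete the future asynchronously rather than unconditionally reporting success.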


@@ -130,7 +130,7 @@ repository for the full implementation.

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
-c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allgather(
+c10::intrusive_ptr<Work> ProcessGroupDummy::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& /* unused */) {
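The body of this function is collapsed in the diff view. As a sketch only, assuming the implementation follows the companion dummy_collectives repository, the zero-filling ``allgather`` after this change might read in full:

// Sketch (assumption): dummy allgather that zeroes the outputs and returns
// an already-completed WorkDummy wrapping a future of the output tensors.
c10::intrusive_ptr<Work> ProcessGroupDummy::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& /* unused */) {
  // The "communication" simply zeroes every output tensor.
  for (auto& outputTensorVec : outputTensors) {
    for (auto& outputTensor : outputTensorVec) {
      outputTensor.zero_();
    }
  }

  // Wrap the already-written outputs in a completed future so application
  // code can wait() on the returned Work object.
  auto future = c10::make_intrusive<c10::ivalue::Future>(
      c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
  future->markCompleted(c10::IValue(outputTensors));
  return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

Since the returned ``WorkDummy`` is already completed, a subsequent ``wait()`` in application code returns immediately.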
@@ -148,7 +148,7 @@ repository for the full implementation.

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
-c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allreduce(
+c10::intrusive_ptr<Work> ProcessGroupDummy::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
for (auto& tensor : tensors) {
@@ -278,11 +278,11 @@ as if it is a builtin backend.

x = torch.ones(6)
dist.all_reduce(x)
-y = x.cuda()
-dist.all_reduce(y)
-
print(f"cpu allreduce: {x}")
-print(f"cuda allreduce: {y}")
+if torch.cuda.is_available():
+    y = x.cuda()
+    dist.all_reduce(y)
+    print(f"cuda allreduce: {y}")

try:
dist.broadcast(x, 0)
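The ``try``/``except`` block continues in the collapsed part of the diff. To actually run this snippet, the collapsed context earlier in the file has to import the extension and initialize the process group. A minimal sketch, assuming the extension module is named ``dummy_collectives`` and registers itself as the ``dummy`` backend when imported (as in the companion repository):

import os

import torch
import torch.distributed as dist

# Assumption: importing the extension registers "dummy" as a backend,
# e.g. via dist.Backend.register_backend("dummy", createProcessGroupDummy).
import dummy_collectives  # noqa: F401

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group("dummy", rank=0, world_size=1)

x = torch.ones(6)
dist.all_reduce(x)
print(f"cpu allreduce: {x}")

Running with a single process is enough for a smoke test; the ``try``/``except`` around ``dist.broadcast`` then demonstrates that collectives without a custom implementation error out, as noted in the header comment above.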