diff --git a/intermediate_source/process_group_cpp_extension_tutorial.rst b/intermediate_source/process_group_cpp_extension_tutorial.rst index 49173e9a4cf..15eb23bf3ee 100644 --- a/intermediate_source/process_group_cpp_extension_tutorial.rst +++ b/intermediate_source/process_group_cpp_extension_tutorial.rst @@ -1,7 +1,7 @@ Customize Process Group Backends Using Cpp Extensions ===================================================== -**Author**: `Feng Tian `__, `Shen Li `__ +**Author**: `Feng Tian `__, `Shen Li `__, `Min Si `__ .. note:: |edit| View and edit this tutorial in `github `__. @@ -62,7 +62,7 @@ Step 1: Implement a Subclass of ``ProcessGroup`` This first step is to implement a ``ProcessGroup`` subclass that overrides target collective communication APIs and runs the custom communication algorithm. -The extension also needs to implement a ``ProcessGroup::Work`` subclass, which +The extension also needs to implement a ``Work`` subclass, which serves as a future of communication results and allows asynchronous execution in application code. If the extension uses third-party libraries, it can include the headers and call into the library APIs from the ``ProcessGroupDummy`` @@ -75,10 +75,11 @@ repository for the full implementation. // file name: dummy.hpp #include - #include - #include - #include - #include + #include + #include + #include + #include + #include #include @@ -86,38 +87,37 @@ repository for the full implementation. class ProcessGroupDummy : public ProcessGroup { public: - - class WorkDummy : public ProcessGroup::Work { - public: - WorkDummy( - OpType opType, - c10::intrusive_ptr future) // future of the output - : ProcessGroup::Work( - -1, // rank, only used by recvAnySource, irrelevant in this demo - opType), - future_(std::move(future)) {} - // There are several additional helper functions that need to be - // implemented. Please refer to https://github.com/mrshenli/dummy_collectives - // for the full implementation. - - private: - c10::intrusive_ptr future_; - }; - ProcessGroupDummy(int rank, int size); - c10::intrusive_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; // The collective communication APIs without a custom implementation // will error out if invoked by application code. }; + + class WorkDummy : public Work { + public: + WorkDummy( + OpType opType, + c10::intrusive_ptr future) // future of the output + : Work( + -1, // rank, only used by recvAnySource, irrelevant in this demo + opType), + future_(std::move(future)) {} + // There are several additional helper functions that need to be + // implemented. Please refer to https://github.com/mrshenli/dummy_collectives + // for the full implementation. + + private: + c10::intrusive_ptr future_; + }; } // namespace c10d @@ -130,7 +130,7 @@ repository for the full implementation. // This is a dummy allgather that sets all output tensors to zero // Modify the implementation to conduct real communication asynchronously - c10::intrusive_ptr ProcessGroupDummy::allgather( + c10::intrusive_ptr ProcessGroupDummy::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& /* unused */) { @@ -148,7 +148,7 @@ repository for the full implementation. // This is a dummy allreduce that sets all output tensors to zero // Modify the implementation to conduct real communication asynchronously - c10::intrusive_ptr ProcessGroupDummy::allreduce( + c10::intrusive_ptr ProcessGroupDummy::allreduce( std::vector& tensors, const AllreduceOptions& opts) { for (auto& tensor : tensors) { @@ -278,11 +278,11 @@ as if it is an builtin backend. x = torch.ones(6) dist.all_reduce(x) - y = x.cuda() - dist.all_reduce(y) - print(f"cpu allreduce: {x}") - print(f"cuda allreduce: {y}") + if torch.cuda.is_available(): + y = x.cuda() + dist.all_reduce(y) + print(f"cuda allreduce: {y}") try: dist.broadcast(x, 0)