
Commit af754cb

mrshenli and brianjo authored
Add a recipe for CUDA RPC (#1429)
Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent 7387edb commit af754cb

File tree

3 files changed: +157 -3 lines changed


recipes_source/cuda_rpc.rst

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
Direct Device-to-Device Communication with TensorPipe CUDA RPC
==============================================================

.. note:: Direct device-to-device RPC (CUDA RPC) is introduced in PyTorch 1.8
   as a prototype feature. This API is subject to change.

In this recipe, you will learn:

- The high-level idea of CUDA RPC.
- How to use CUDA RPC.


Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`_


What is CUDA RPC?
------------------------------------

CUDA RPC supports directly sending Tensors from local CUDA memory to remote
CUDA memory. Prior to the v1.8 release, PyTorch RPC only accepted CPU Tensors.
As a result, when an application needed to send a CUDA Tensor through RPC, it
had to first move the Tensor to the CPU on the caller, send it via RPC, and
then move it to the destination device on the callee, which incurred both
unnecessary synchronizations and D2H and H2D copies. Since v1.8, RPC allows
users to configure a per-process global device map using the
`set_device_map <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.TensorPipeRpcBackendOptions.set_device_map>`_
API, specifying how to map local devices to remote devices. More specifically,
if ``worker0``'s device map has an entry ``"worker1" : {"cuda:0" : "cuda:1"}``,
all RPC arguments on ``"cuda:0"`` from ``worker0`` will be directly sent to
``"cuda:1"`` on ``worker1``. The response of an RPC will use the inverse of
the caller's device map, i.e., if ``worker1`` returns a Tensor on ``"cuda:1"``,
it will be directly sent to ``"cuda:0"`` on ``worker0``. All intended
device-to-device direct communication must be specified in the per-process
device map. Otherwise, only CPU tensors are allowed.
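
For example, a minimal sketch of configuring such a device map on ``worker0``
could look like the following. This is illustrative only: the worker names,
ranks, and rendezvous settings are assumptions, and the full runnable example
appears in the next section.

::

    import os
    import torch.distributed.rpc as rpc

    # illustrative rendezvous settings for a two-process job
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    options = rpc.TensorPipeRpcBackendOptions()
    # send Tensors on worker0's "cuda:0" directly to "cuda:1" on worker1;
    # responses from worker1 automatically use the inverse mapping
    options.set_device_map("worker1", {"cuda:0": "cuda:1"})

    rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)
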
Under the hood, PyTorch RPC relies on `TensorPipe <https://github.com/pytorch/tensorpipe>`_
as the communication backend. PyTorch RPC extracts all Tensors from each
request or response into a list and packs everything else into a binary
payload. Then, TensorPipe will automatically choose a communication channel
for each Tensor based on Tensor device type and channel availability on both
the caller and the callee. Existing TensorPipe channels cover NVLink, InfiniBand,
SHM, CMA, TCP, etc.

How to use CUDA RPC?
---------------------------------------

The code below shows how to use CUDA RPC. The model contains two linear layers
and is split into two shards. The two shards are placed on ``worker0`` and
``worker1`` respectively, and ``worker0`` serves as the master that drives the
forward and backward passes. Note that we intentionally skipped
`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_
to highlight the performance improvements when using CUDA RPC. The experiment
repeats the forward and backward passes 10 times and measures the total
execution time. It compares using CUDA RPC against manually staging to CPU
memory and using CPU RPC.


::

    import torch
    import torch.distributed.autograd as autograd
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp
    import torch.nn as nn

    import os
    import time


    class MyModule(nn.Module):
        def __init__(self, device, comm_mode):
            super().__init__()
            self.device = device
            self.linear = nn.Linear(1000, 1000).to(device)
            self.comm_mode = comm_mode

        def forward(self, x):
            # x.to() is a no-op if x is already on self.device
            y = self.linear(x.to(self.device))
            return y.cpu() if self.comm_mode == "cpu" else y

        def parameter_rrefs(self):
            return [rpc.RRef(p) for p in self.parameters()]


    def measure(comm_mode):
        # local module on "worker0/cuda:0"
        lm = MyModule("cuda:0", comm_mode)
        # remote module on "worker1/cuda:1"
        rm = rpc.remote("worker1", MyModule, args=("cuda:1", comm_mode))
        # prepare random inputs
        x = torch.randn(1000, 1000).cuda(0)

        tik = time.time()
        for _ in range(10):
            with autograd.context() as ctx:
                y = rm.rpc_sync().forward(lm(x))
                autograd.backward(ctx, [y.sum()])
        # synchronize on "cuda:0" to make sure that all pending CUDA ops are
        # included in the measurements
        torch.cuda.current_stream("cuda:0").synchronize()
        tok = time.time()
        print(f"{comm_mode} RPC total execution time: {tok - tik}")


    def run_worker(rank):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=128)

        if rank == 0:
            options.set_device_map("worker1", {0: 1})
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=2,
                rpc_backend_options=options
            )
            measure(comm_mode="cpu")
            measure(comm_mode="cuda")
        else:
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=2,
                rpc_backend_options=options
            )

        # block until all rpcs finish
        rpc.shutdown()


    if __name__=="__main__":
        world_size = 2
        mp.spawn(run_worker, nprocs=world_size, join=True)

The outputs are displayed below, showing that CUDA RPC helps achieve a 34X
speedup compared to CPU RPC in this experiment.

::

    cpu RPC total execution time: 2.3145179748535156 Seconds
    cuda RPC total execution time: 0.06867480278015137 Seconds
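
As noted above, the experiment intentionally leaves out the optimizer step to
isolate communication cost. For completeness, a minimal sketch of how a
`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_
step could be added inside ``measure`` is shown below; the optimizer class and
learning rate are illustrative, and this code is not part of the measured
experiment above.

::

    from torch.distributed.optim import DistributedOptimizer

    # collect parameter RRefs from the local shard and the remote shard
    param_rrefs = lm.parameter_rrefs() + rm.rpc_sync().parameter_rrefs()
    opt = DistributedOptimizer(torch.optim.SGD, param_rrefs, lr=0.05)

    with autograd.context() as ctx:
        y = rm.rpc_sync().forward(lm(x))
        autograd.backward(ctx, [y.sum()])
        # run a distributed optimizer step using the gradients accumulated
        # in this distributed autograd context
        opt.step(ctx)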

recipes_source/recipes_index.rst

Lines changed: 8 additions & 0 deletions
@@ -234,6 +234,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/zero_redundancy_optimizer.html
    :tags: Distributed-Training

+.. customcarditem::
+   :header: Direct Device-to-Device Communication with TensorPipe RPC
+   :card_description: How to use RPC with direct GPU-to-GPU communication.
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/cuda_rpc.html
+   :tags: Distributed-Training
+
 .. End of tutorial card section

 .. raw:: html
@@ -271,3 +278,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/deployment_with_flask
    /recipes/distributed_rpc_profiling
    /recipes/zero_redundancy_optimizer
+   /recipes/cuda_rpc

recipes_source/zero_redundancy_optimizer.rst

Lines changed: 2 additions & 3 deletions
@@ -1,9 +1,8 @@
 Shard Optimizer States with ZeroRedundancyOptimizer
 ===================================================

-.. note:
-   `ZeroRedundancyOptimizer` is introduced in PyTorch 1.8 as a prototype
-   feature. It API is subject to change.
+.. note:: `ZeroRedundancyOptimizer` is introduced in PyTorch 1.8 as a prototype
+   feature. This API is subject to change.

 In this recipe, you will learn:
