
Commit f0a0176

Merge branch 'master' into patch-1

2 parents: 8b9b05a + cba6b85

8 files changed: +686 additions, -2 deletions
Two binary image files added (14.5 KB and 41.9 KB).

advanced_source/dispatcher.rst

Lines changed: 24 additions & 0 deletions
@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
 that provides implementations for all basic operators on the XLA dispatch
 key.
 
+.. _autograd-support:
+
 Adding autograd support
 -----------------------
 
@@ -299,6 +301,28 @@ the safest choice for the execution type:
                          at::autocast::cached_cast(exec_type, t1));
   }
 
+If your custom op is :ref:`autograd-enabled<autograd-support>`, you only need to write and register
+an autocast wrapper for the same name onto which the autograd wrapper is registered.
+For example, if you wanted an autocast wrapper for the ``myadd`` function shown
+in the autograd section, all you'd need is
+
+.. code-block:: cpp
+
+  Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    return myadd(at::autocast::cached_cast(<desired dtype>, self),
+                 at::autocast::cached_cast(<desired dtype>, other));
+  }
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("myadd", myadd_autocast);
+  }
+
+There are no separate gymnastics to make the backward method autocast compatible.
+However, the backward method defined in your custom autograd function will run in the same
+dtype as autocast sets for the forward method, so you should choose a ``<desired dtype>``
+suitable for both your forward and backward methods.
+
 Batched
 ^^^^^^^
 
beginner_source/transformer_tutorial.py

Lines changed: 2 additions & 2 deletions
@@ -77,9 +77,9 @@ def init_weights(self):
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
     def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != len(src):
+        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
             device = src.device
-            mask = self._generate_square_subsequent_mask(len(src)).to(device)
+            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
             self.src_mask = mask
 
         src = self.encoder(src) * math.sqrt(self.ninp)
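
For context, a minimal sketch (assuming the tutorial's ``[seq_len, batch]`` input layout; the shapes below are hypothetical) of why the check now reads ``src.size(0)``: for a tensor the two spellings agree, but ``size(0)`` states explicitly that the mask is keyed to the sequence-length dimension.

::

    import torch

    # Hypothetical shapes for illustration: seq_len=35, batch=20.
    src = torch.zeros(35, 20, dtype=torch.long)
    # Equivalent for tensors; size(0) names the sequence-length intent.
    assert len(src) == src.size(0) == 35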
Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
Profiling PyTorch RPC-Based Workloads
======================================

In this recipe, you will learn:

- An overview of the `Distributed RPC Framework`_
- An overview of the `PyTorch Profiler`_
- How to use the profiler to profile RPC-based workloads

Requirements
------------

- PyTorch 1.6

The instructions for installing PyTorch are
available at `pytorch.org`_.

What is the Distributed RPC Framework?
---------------------------------------

The **Distributed RPC Framework** provides mechanisms for multi-machine model
training through a set of primitives to allow for remote communication, and a
higher-level API to automatically differentiate models split across several machines.
For this recipe, it would be helpful to be familiar with the `Distributed RPC Framework`_
as well as the `RPC Tutorials`_.
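
As a minimal sketch of those primitives (illustrative only; it assumes an RPC group has already been initialized and that a peer named ``worker1`` exists):

::

    import torch
    import torch.distributed.rpc as rpc

    # Blocking call: returns the result directly.
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))
    # Non-blocking call: returns a Future to be awaited.
    fut = rpc.rpc_async("worker1", torch.mul, args=(torch.ones(2), 2))
    # Remote reference: the value stays on the remote worker until fetched.
    rref = rpc.remote("worker1", torch.zeros, args=(2, 2))
    print(ret, fut.wait(), rref.to_here())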
What is the PyTorch Profiler?
---------------------------------------

The profiler is a context manager based API that allows for on-demand profiling of
operators in a model's workload. The profiler can be used to analyze various aspects
of a model including execution time, operators invoked, and memory consumption. For a
detailed tutorial on using the profiler to profile a single-node model, please see the
`Profiler Recipe`_.
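
Before moving on to the distributed case, here is a minimal single-node sketch of that context manager API (illustrative only; see the `Profiler Recipe`_ for a full treatment):

::

    import torch
    import torch.autograd.profiler as profiler

    x = torch.randn(128, 128)
    with profiler.profile(profile_memory=True) as prof:
        y = torch.mm(x, x)
    # Aggregate recorded events by name and print a summary table.
    print(prof.key_averages().table(sort_by="cpu_time_total"))
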
How to use the Profiler for RPC-based workloads
-----------------------------------------------

The profiler supports profiling of calls made over RPC and allows the user to have a
detailed view into the operations that take place on different nodes. To demonstrate
this, let's first set up the RPC framework. The below code snippet will initialize
two RPC workers on the same host, named ``worker0`` and ``worker1`` respectively. The workers will
be spawned as subprocesses, and we set some environment variables required for proper
initialization.

::

    import torch
    import torch.distributed.rpc as rpc
    import torch.autograd.profiler as profiler
    import torch.multiprocessing as mp
    import os
    import logging
    import sys

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    logger = logging.getLogger()

    def random_tensor():
        return torch.rand((3, 3), requires_grad=True)


    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        worker_name = f"worker{rank}"

        # Initialize RPC framework.
        rpc.init_rpc(
            name=worker_name,
            rank=rank,
            world_size=world_size
        )
        logger.debug(f"{worker_name} successfully initialized RPC.")

        pass  # to be continued below

        logger.debug(f"Rank {rank} waiting for workers and shutting down RPC")
        rpc.shutdown()
        logger.debug(f"Rank {rank} shutdown RPC")


    if __name__ == '__main__':
        # Run 2 RPC workers.
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)

Running the above program should present you with the following output:

::

    DEBUG:root:worker1 successfully initialized RPC.
    DEBUG:root:worker0 successfully initialized RPC.
    DEBUG:root:Rank 0 waiting for workers and shutting down RPC
    DEBUG:root:Rank 1 waiting for workers and shutting down RPC
    DEBUG:root:Rank 1 shutdown RPC
    DEBUG:root:Rank 0 shutdown RPC

Now that we have a skeleton setup of our RPC framework, we can move on to
sending RPCs back and forth and using the profiler to obtain a view of what's
happening under the hood. Let's add to the above ``worker`` function:

::

    def worker(rank, world_size):
        # Above code omitted...
        if rank == 0:
            dst_worker_rank = (rank + 1) % world_size
            dst_worker_name = f"worker{dst_worker_rank}"
            t1, t2 = random_tensor(), random_tensor()
            # Send and wait for RPC completion under profiling scope.
            with profiler.profile() as prof:
                fut1 = rpc.rpc_async(dst_worker_name, torch.add, args=(t1, t2))
                fut2 = rpc.rpc_async(dst_worker_name, torch.mul, args=(t1, t2))
                # RPCs must be awaited within profiling scope.
                fut1.wait()
                fut2.wait()

            print(prof.key_averages().table())

The aforementioned code creates two RPCs, specifying ``torch.add`` and ``torch.mul``, respectively,
to be run with two random input tensors on worker 1. Since we use the ``rpc_async`` API,
we are returned a ``torch.futures.Future`` object, which must be awaited for the result
of the computation. Note that this wait must take place within the scope created by
the profiling context manager in order for the RPC to be accurately profiled. Running
the code with this new worker function should result in the following output:

::

    # Some columns are omitted for brevity, exact output subject to randomness
    ----------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
    Name                                                               Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  Node ID
    ----------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
    rpc_async#aten::add(worker0 -> worker1)                            0.00%             0.000us          0                20.462ms         20.462ms         1                0
    rpc_async#aten::mul(worker0 -> worker1)                            0.00%             0.000us          0                5.712ms          5.712ms          1                0
    rpc_async#aten::mul(worker0 -> worker1)#remote_op: mul             1.84%             206.864us        2.69%            302.162us        151.081us        2                1
    rpc_async#aten::add(worker0 -> worker1)#remote_op: add             1.41%             158.501us        1.57%            176.924us        176.924us        1                1
    rpc_async#aten::mul(worker0 -> worker1)#remote_op: output_nr       0.04%             4.980us          0.04%            4.980us          2.490us          2                1
    rpc_async#aten::mul(worker0 -> worker1)#remote_op: is_leaf         0.07%             7.806us          0.07%            7.806us          1.952us          4                1
    rpc_async#aten::add(worker0 -> worker1)#remote_op: empty           0.16%             18.423us         0.16%            18.423us         18.423us         1                1
    rpc_async#aten::mul(worker0 -> worker1)#remote_op: empty            0.14%             15.712us         0.14%            15.712us         15.712us         1                1
    ----------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
    Self CPU time total: 11.237ms

Here we can see that the profiler has profiled our ``rpc_async`` calls made to ``worker1``
from ``worker0``. In particular, the first two entries in the table show details (such as
the operator name, originating worker, and destination worker) about each RPC call made,
and the ``CPU total`` column indicates the end-to-end latency of the RPC call.

We also have visibility into the actual operators invoked remotely on worker 1 due to the RPC.
We can see operations that took place on ``worker1`` by checking the ``Node ID`` column. For
example, we can interpret the row with name ``rpc_async#aten::mul(worker0 -> worker1)#remote_op: mul``
as a ``mul`` operation taking place on the remote node, as a result of the RPC sent to ``worker1``
from ``worker0``, specifying ``worker1`` to run the builtin ``mul`` operator on the input tensors.
Note that names of remote operations are prefixed with the name of the RPC event that resulted
in them. For example, remote operations corresponding to the ``rpc.rpc_async(dst_worker_name, torch.add, args=(t1, t2))``
call are prefixed with ``rpc_async#aten::add(worker0 -> worker1)``.
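
These summary tables can get long once more operators are involved. A small illustrative variation of the ``print`` above (the same ``prof`` object is assumed) sorts the aggregated events and limits the number of rows:

::

    # Sort by total CPU time and keep only the top 10 rows.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
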
We can also use the profiler to gain insight into user-defined functions that are executed over RPC.
For example, let's add the following to the above ``worker`` function:

::

    # Define somewhere outside of worker() func.
    def udf_with_ops():
        import time
        time.sleep(1)
        t1, t2 = random_tensor(), random_tensor()
        torch.add(t1, t2)
        torch.mul(t1, t2)

    def worker(rank, world_size):
        # Above code omitted
        with profiler.profile() as p:
            fut = rpc.rpc_async(dst_worker_name, udf_with_ops)
            fut.wait()

        print(p.key_averages().table())

The above code creates a user-defined function that sleeps for 1 second, and then executes various
operators. Similar to what we've done above, we send an RPC to the remote worker, directing it to
run our user-defined function. Running this code should result in the following output:

::

    # Exact output subject to randomness
    --------------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
    Name                                                                   Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  Node ID
    --------------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
    rpc_async#udf_with_ops(worker0 -> worker1)                             0.00%             0.000us          0                1.008s           1.008s           1                0
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: rand             12.58%            80.037us         47.09%           299.589us        149.795us        2                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: empty            15.40%            98.013us         15.40%           98.013us         24.503us         4                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: uniform_         22.85%            145.358us        23.87%           151.870us        75.935us         2                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: is_complex       1.02%             6.512us          1.02%            6.512us          3.256us          2                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: add              25.80%            164.179us        28.43%           180.867us        180.867us        1                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: mul              20.48%            130.293us        31.43%           199.949us        99.975us         2                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: output_nr        0.71%             4.506us          0.71%            4.506us          2.253us          2                1
    rpc_async#udf_with_ops(worker0 -> worker1)#remote_op: is_leaf          1.16%             7.367us          1.16%            7.367us          1.842us          4                1
    --------------------------------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------

Here we can see that the user-defined function has successfully been profiled under the name
``rpc_async#udf_with_ops(worker0 -> worker1)``, and has the CPU total time we would roughly expect
(slightly greater than 1s given the ``sleep``). Similar to the above profiling output, we can see the
remote operators that have been executed on worker 1 as part of executing this RPC request.
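
Note that ``udf_with_ops`` returns nothing, so ``fut.wait()`` simply blocks until the remote work finishes. As an illustrative variation (same setup assumed, function name hypothetical), a user-defined function can also return a value, which the ``Future`` delivers back to the caller:

::

    def udf_returning_value():
        t1, t2 = random_tensor(), random_tensor()
        return torch.add(t1, t2)

    # Inside the rank 0 branch of worker():
    fut = rpc.rpc_async(dst_worker_name, udf_returning_value)
    result = fut.wait()  # a 3x3 tensor computed on the remote worker
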
Lastly, we can visualize remote execution using the tracing functionality provided by the profiler.
Let's add the following code to the above ``worker`` function:

::

    def worker(rank, world_size):
        # Above code omitted
        # Will generate trace for above profiling output
        trace_file = "/tmp/trace.json"
        prof.export_chrome_trace(trace_file)
        logger.debug(f"Wrote trace to {trace_file}")

Now, we can load the trace file in Chrome (``chrome://tracing``). We should see output similar to
the following:

.. image:: ../_static/img/rpc_trace_img.png
   :scale: 25 %

As we can see, we have traced our RPC requests and can also visualize traces of the remote operations,
in this case, given in the trace row for ``node_id: 1``.

Putting it all together, we have the following code for this recipe:

::

    import torch
    import torch.distributed.rpc as rpc
    import torch.autograd.profiler as profiler
    import torch.multiprocessing as mp
    import os
    import logging
    import sys

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    logger = logging.getLogger()

    def random_tensor():
        return torch.rand((3, 3), requires_grad=True)

    def udf_with_ops():
        import time
        time.sleep(1)
        t1, t2 = random_tensor(), random_tensor()
        torch.add(t1, t2)
        torch.mul(t1, t2)

    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        worker_name = f"worker{rank}"

        # Initialize RPC framework.
        rpc.init_rpc(
            name=worker_name,
            rank=rank,
            world_size=world_size
        )
        logger.debug(f"{worker_name} successfully initialized RPC.")

        if rank == 0:
            dst_worker_rank = (rank + 1) % world_size
            dst_worker_name = f"worker{dst_worker_rank}"
            t1, t2 = random_tensor(), random_tensor()
            # Send and wait for RPC completion under profiling scope.
            with profiler.profile() as prof:
                fut1 = rpc.rpc_async(dst_worker_name, torch.add, args=(t1, t2))
                fut2 = rpc.rpc_async(dst_worker_name, torch.mul, args=(t1, t2))
                # RPCs must be awaited within profiling scope.
                fut1.wait()
                fut2.wait()
            print(prof.key_averages().table())

            with profiler.profile() as p:
                fut = rpc.rpc_async(dst_worker_name, udf_with_ops)
                fut.wait()

            print(p.key_averages().table())

            trace_file = "/tmp/trace.json"
            prof.export_chrome_trace(trace_file)
            logger.debug(f"Wrote trace to {trace_file}")


        logger.debug(f"Rank {rank} waiting for workers and shutting down RPC")
        rpc.shutdown()
        logger.debug(f"Rank {rank} shutdown RPC")


    if __name__ == '__main__':
        # Run 2 RPC workers.
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)


Learn More
-------------------

- `pytorch.org`_ for installation instructions, and more documentation
  and tutorials.
- `Distributed RPC Framework`_ for RPC framework and API reference.
- `Full profiler documentation`_ for profiler documentation.

.. _pytorch.org: https://pytorch.org/
.. _Full profiler documentation: https://pytorch.org/docs/stable/autograd.html#profiler
.. _PyTorch Profiler: https://pytorch.org/docs/stable/autograd.html#profiler
.. _Distributed RPC Framework: https://pytorch.org/docs/stable/rpc.html
.. _RPC Tutorials: https://pytorch.org/tutorials/intermediate/rpc_tutorial.html
.. _Profiler Recipe: https://pytorch.org/tutorials/recipes/recipes/profiler.html

recipes_source/recipes/README.txt

Lines changed: 4 additions & 0 deletions
@@ -56,3 +56,7 @@ PyTorch Recipes
 14. mobile_perf.py
          PyTorch Mobile Performance Recipes
          https://pytorch.org/tutorials/recipes/mobile_perf.html
+
+15. amp_recipe.py
+         Automatic Mixed Precision
+         https://pytorch.org/tutorials/recipes/amp_recipe.html
