Commit 9568cd2

Merge branch 'master' into 60mb_setup

2 parents f56befe + 5ea0ff6

File tree: 3 files changed, +326 -2 lines changed

beginner_source/profiler.py

Lines changed: 316 additions & 0 deletions

@@ -0,0 +1,316 @@
"""
Profiling your PyTorch Module
-----------------------------
**Author:** `Suraj Subramanian <https://github.com/suraj813>`_

PyTorch includes a profiler API that is useful to identify the time and
memory costs of various PyTorch operations in your code. Profiler can be
easily integrated in your code, and the results can be printed as a table
or returned in a JSON trace file.

.. note::
    Profiler supports multithreaded models. Profiler runs in the
    same thread as the operation but it will also profile child operators
    that might run in another thread. Concurrently-running profilers will be
    scoped to their own thread to prevent mixing of results.

Head on over to `this
recipe <https://pytorch.org/tutorials/recipes/recipes/profiler.html>`__
for a quicker walkthrough of Profiler API usage.


--------------
"""

import torch
import numpy as np
from torch import nn
import torch.autograd.profiler as profiler
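
######################################################################
# As a quick, minimal sketch of the two output formats mentioned above:
# ``table()`` prints the aggregated results, and ``export_chrome_trace``
# writes a JSON trace file that can be viewed at ``chrome://tracing``.
# (The path ``quick_trace.json`` below is just an example.)

with profiler.profile() as quick_prof:
    torch.mm(torch.rand(1000, 1000), torch.rand(1000, 1000))

print(quick_prof.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
quick_prof.export_chrome_trace("quick_trace.json")  # JSON trace file
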
######################################################################
# Performance debugging using Profiler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiler can be useful to identify performance bottlenecks in your
# models. In this example, we build a custom module that performs two
# sub-tasks:
#
# - a linear transformation on the input, and
# - an index lookup on a mask tensor, using the result of that transformation.
#
# We wrap the code for each sub-task in separate labelled context managers using
# ``profiler.record_function("label")``. In the profiler output, the
# aggregate performance metrics of all operations in the sub-task will
# show up under its corresponding label.
#
#
# Note that using Profiler incurs some overhead, and is best used only for investigating
# code. Remember to remove it if you are benchmarking runtimes.
#

class MyModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean().item()
            hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
            hi_idx = torch.from_numpy(hi_idx).cuda()

        return out, hi_idx


######################################################################
# Profile the forward pass
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We initialize random input and mask tensors, and the model.
#
# Before we run the profiler, we warm up CUDA to ensure accurate
# performance benchmarking. We wrap the forward pass of our module in the
# ``profiler.profile`` context manager. The ``with_stack=True`` parameter appends the
# file and line number of the operation in the trace.
#
# .. WARNING::
#     ``with_stack=True`` incurs an additional overhead, and is better suited for investigating code.
#     Remember to remove it if you are benchmarking performance.
#

model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.double).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

######################################################################
# Print profiler results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Finally, we print the profiler results. ``profiler.key_averages``
# aggregates the results by operator name, and optionally by input
# shapes and/or stack trace events.
# Grouping by input shapes is useful to identify which tensor shapes
# are utilized by the model.
#
# Here, we use ``group_by_stack_n=5``, which aggregates runtimes by the
# operation and its traceback (truncated to the most recent 5 events), and
# displays the events in the order they are registered. The table can also
# be sorted by passing a ``sort_by`` argument (refer to the
# `docs <https://pytorch.org/docs/stable/autograd.html#profiler>`__ for
# valid sorting keys).
#
# .. Note::
#    When running profiler in a notebook, you might see entries like ``<ipython-input-18-193a910735e8>(13): forward``
#    instead of filenames in the stacktrace. These correspond to ``<notebook-cell>(line number): calling-function``.

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
120+
(Some columns are omitted)
121+
122+
------------- ------------ ------------ ------------ ---------------------------------
123+
Name Self CPU % Self CPU Self CPU Mem Source Location
124+
------------- ------------ ------------ ------------ ---------------------------------
125+
MASK INDICES 87.88% 5.212s -953.67 Mb /mnt/xarfuse/.../torch/au
126+
<ipython-input-...>(10): forward
127+
/mnt/xarfuse/.../torch/nn
128+
<ipython-input-...>(9): <module>
129+
/mnt/xarfuse/.../IPython/
130+
131+
aten::copy_ 12.07% 715.848ms 0 b <ipython-input-...>(12): forward
132+
/mnt/xarfuse/.../torch/nn
133+
<ipython-input-...>(9): <module>
134+
/mnt/xarfuse/.../IPython/
135+
/mnt/xarfuse/.../IPython/
136+
137+
LINEAR PASS 0.01% 350.151us -20 b /mnt/xarfuse/.../torch/au
138+
<ipython-input-...>(7): forward
139+
/mnt/xarfuse/.../torch/nn
140+
<ipython-input-...>(9): <module>
141+
/mnt/xarfuse/.../IPython/
142+
143+
aten::addmm 0.00% 293.342us 0 b /mnt/xarfuse/.../torch/nn
144+
/mnt/xarfuse/.../torch/nn
145+
/mnt/xarfuse/.../torch/nn
146+
<ipython-input-...>(8): forward
147+
/mnt/xarfuse/.../torch/nn
148+
149+
aten::mean 0.00% 235.095us 0 b <ipython-input-...>(11): forward
150+
/mnt/xarfuse/.../torch/nn
151+
<ipython-input-...>(9): <module>
152+
/mnt/xarfuse/.../IPython/
153+
/mnt/xarfuse/.../IPython/
154+
155+
----------------------------- ------------ ---------- ----------------------------------
156+
Self CPU time total: 5.931s
157+
158+
"""

######################################################################
# Improve memory performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Note that the most expensive operations - in terms of memory and time -
# are at ``forward (10)``, representing the operations within MASK INDICES. Let’s try to
# tackle the memory consumption first. We can see that the ``.to()``
# operation at line 12 consumes 953.67 Mb. This operation copies ``mask`` to the CPU.
# ``mask`` is initialized with a ``torch.double`` datatype. Can we reduce the memory footprint by casting
# it to ``torch.float`` instead?
#

model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
184+
(Some columns are omitted)
185+
186+
----------------- ------------ ------------ ------------ --------------------------------
187+
Name Self CPU % Self CPU Self CPU Mem Source Location
188+
----------------- ------------ ------------ ------------ --------------------------------
189+
MASK INDICES 93.61% 5.006s -476.84 Mb /mnt/xarfuse/.../torch/au
190+
<ipython-input-...>(10): forward
191+
/mnt/xarfuse/ /torch/nn
192+
<ipython-input-...>(9): <module>
193+
/mnt/xarfuse/.../IPython/
194+
195+
aten::copy_ 6.34% 338.759ms 0 b <ipython-input-...>(12): forward
196+
/mnt/xarfuse/.../torch/nn
197+
<ipython-input-...>(9): <module>
198+
/mnt/xarfuse/.../IPython/
199+
/mnt/xarfuse/.../IPython/
200+
201+
aten::as_strided 0.01% 281.808us 0 b <ipython-input-...>(11): forward
202+
/mnt/xarfuse/.../torch/nn
203+
<ipython-input-...>(9): <module>
204+
/mnt/xarfuse/.../IPython/
205+
/mnt/xarfuse/.../IPython/
206+
207+
aten::addmm 0.01% 275.721us 0 b /mnt/xarfuse/.../torch/nn
208+
/mnt/xarfuse/.../torch/nn
209+
/mnt/xarfuse/.../torch/nn
210+
<ipython-input-...>(8): forward
211+
/mnt/xarfuse/.../torch/nn
212+
213+
aten::_local 0.01% 268.650us 0 b <ipython-input-...>(11): forward
214+
_scalar_dense /mnt/xarfuse/.../torch/nn
215+
<ipython-input-...>(9): <module>
216+
/mnt/xarfuse/.../IPython/
217+
/mnt/xarfuse/.../IPython/
218+
219+
----------------- ------------ ------------ ------------ --------------------------------
220+
Self CPU time total: 5.347s
221+
222+
"""

######################################################################
#
# The CPU memory footprint for this operation has halved.
#
# Improve time performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# While the time consumed has also reduced a bit, it’s still too high.
# It turns out that copying a matrix from CUDA to CPU is pretty expensive!
# The ``aten::copy_`` operator in ``forward (12)`` copies ``mask`` to CPU
# so that it can use the NumPy ``argwhere`` function. ``aten::copy_`` at ``forward (13)``
# copies the array back to CUDA as a tensor. We could eliminate both of these if we use the
# ``torch`` function ``nonzero()`` here instead.
#

class MyModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean()
            hi_idx = (mask > threshold).nonzero(as_tuple=True)

        return out, hi_idx


model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
267+
(Some columns are omitted)
268+
269+
-------------- ------------ ------------ ------------ ---------------------------------
270+
Name Self CPU % Self CPU Self CPU Mem Source Location
271+
-------------- ------------ ------------ ------------ ---------------------------------
272+
aten::gt 57.17% 129.089ms 0 b <ipython-input-...>(12): forward
273+
/mnt/xarfuse/.../torch/nn
274+
<ipython-input-...>(25): <module>
275+
/mnt/xarfuse/.../IPython/
276+
/mnt/xarfuse/.../IPython/
277+
278+
aten::nonzero 37.38% 84.402ms 0 b <ipython-input-...>(12): forward
279+
/mnt/xarfuse/.../torch/nn
280+
<ipython-input-...>(25): <module>
281+
/mnt/xarfuse/.../IPython/
282+
/mnt/xarfuse/.../IPython/
283+
284+
INDEX SCORE 3.32% 7.491ms -119.21 Mb /mnt/xarfuse/.../torch/au
285+
<ipython-input-...>(10): forward
286+
/mnt/xarfuse/.../torch/nn
287+
<ipython-input-...>(25): <module>
288+
/mnt/xarfuse/.../IPython/
289+
290+
aten::as_strided 0.20% 441.587us 0 b <ipython-input-...>(12): forward
291+
/mnt/xarfuse/.../torch/nn
292+
<ipython-input-...>(25): <module>
293+
/mnt/xarfuse/.../IPython/
294+
/mnt/xarfuse/.../IPython/
295+
296+
aten::nonzero
297+
_numpy 0.18% 395.602us 0 b <ipython-input-...>(12): forward
298+
/mnt/xarfuse/.../torch/nn
299+
<ipython-input-...>(25): <module>
300+
/mnt/xarfuse/.../IPython/
301+
/mnt/xarfuse/.../IPython/
302+
-------------- ------------ ------------ ------------ ---------------------------------
303+
Self CPU time total: 225.801ms
304+
305+
"""

######################################################################
# Further Reading
# ~~~~~~~~~~~~~~~~~
# We have seen how Profiler can be used to investigate time and memory bottlenecks in PyTorch models.
# Read more about Profiler here:
#
# - `Profiler Usage Recipe <https://pytorch.org/tutorials/recipes/recipes/profiler.html>`__
# - `Profiling RPC-Based Workloads <https://pytorch.org/tutorials/recipes/distributed_rpc_profiling.html>`__
# - `Profiler API Docs <https://pytorch.org/docs/stable/autograd.html?highlight=profiler#profiler>`__

index.rst

Lines changed: 8 additions & 0 deletions

@@ -275,6 +275,13 @@ Welcome to PyTorch Tutorials
 
 .. Model Optimization
 
+.. customcarditem::
+   :header: Performance Profiling in PyTorch
+   :card_description: Learn how to use the PyTorch Profiler to benchmark your module's performance.
+   :image: _static/img/thumbnails/cropped/profiler.png
+   :link: beginner/profiler.html
+   :tags: Model-Optimization,Best-Practice,Profiling
+
 .. customcarditem::
    :header: Hyperparameter Tuning Tutorial
    :card_description: Learn how to use Ray Tune to find the best performing set of hyperparameters for your model.

@@ -534,6 +541,7 @@ Additional Resources
    :hidden:
    :caption: Model Optimization
 
+   beginner/profiler
    beginner/hyperparameter_tuning_tutorial
    intermediate/pruning_tutorial
    advanced/dynamic_quantization_tutorial

recipes_source/recipes/Captum_Recipe.py

Lines changed: 2 additions & 2 deletions

@@ -136,7 +136,7 @@
 attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1,2,0))
 
 vis_types = ["heat_map", "original_image"]
-vis_signs = ["all", "all"], # "positive", "negative", or "all" to show both
+vis_signs = ["all", "all"] # "positive", "negative", or "all" to show both
 # positive attribution indicates that the presence of the area increases the prediction score
 # negative attribution indicates distractor areas whose absence increases the score

@@ -186,4 +186,4 @@
 #
 # Another useful post by Gilbert Tanner:
 # https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum
-#
+#
