Commit 24946b2

gaoteng-git and maxluk authored

Add a tutorial for Tensorboard profiler (#1380)

* add tensorboard_profiler tutorial
* Update intermediate_source/tensorboard_profiler_tutorial.py
  Co-authored-by: maxluk <maxluk@microsoft.com>
* update title
* remove testing on windows because kineto doesn't support windows now
* rename
* Update with the help of Ilia

Co-authored-by: Teng Gao <tegao@microsoft.com>
Co-authored-by: maxluk <maxluk@microsoft.com>

1 parent d763c40 commit 24946b2

File tree

8 files changed: +246 -0 lines changed

.circleci/scripts/build_for_windows.sh

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ if [[ "${CIRCLE_JOB}" == *worker_* ]]; then
     python $DIR/remove_runnable_code.py advanced_source/static_quantization_tutorial.py advanced_source/static_quantization_tutorial.py || true
     python $DIR/remove_runnable_code.py beginner_source/hyperparameter_tuning_tutorial.py beginner_source/hyperparameter_tuning_tutorial.py || true
     python $DIR/remove_runnable_code.py beginner_source/audio_preprocessing_tutorial.py beginner_source/audio_preprocessing_tutorial.py || true
+    python $DIR/remove_runnable_code.py intermediate_source/tensorboard_profiler_tutorial.py intermediate_source/tensorboard_profiler_tutorial.py || true
     # Temp remove for mnist download issue. (Re-enabled for 1.8.1)
     # python $DIR/remove_runnable_code.py beginner_source/fgsm_tutorial.py beginner_source/fgsm_tutorial.py || true

_static/img/profiler_overview1.PNG (133 KB)

_static/img/profiler_overview2.PNG (77.3 KB)

_static/img/profiler_trace_view1.PNG (128 KB)

_static/img/profiler_trace_view2.PNG (133 KB)

_static/img/profiler_views_list.PNG (67.8 KB)

index.rst

Lines changed: 8 additions & 0 deletions

@@ -318,6 +318,13 @@ Welcome to PyTorch Tutorials
    :link: beginner/profiler.html
    :tags: Model-Optimization,Best-Practice,Profiling

+.. customcarditem::
+   :header: Performance Profiling in Tensorboard
+   :card_description: Learn how to use tensorboard plugin to profile and analyze your model's performance.
+   :image: _static/img/thumbnails/cropped/profiler.png
+   :link: intermediate/tensorboard_profiler_tutorial.html
+   :tags: Model-Optimization,Best-Practice,Profiling
+
 .. customcarditem::
    :header: Hyperparameter Tuning Tutorial
    :card_description: Learn how to use Ray Tune to find the best performing set of hyperparameters for your model.

@@ -627,6 +634,7 @@ Additional Resources
    :caption: Model Optimization

    beginner/profiler
+   intermediate/tensorboard_profiler_tutorial
    beginner/hyperparameter_tuning_tutorial
    intermediate/parametrizations
    intermediate/pruning_tutorial
intermediate_source/tensorboard_profiler_tutorial.py

Lines changed: 237 additions & 0 deletions

@@ -0,0 +1,237 @@
"""
PyTorch Profiler With TensorBoard
====================================
This tutorial demonstrates how to use the TensorBoard plugin with PyTorch Profiler
to detect performance bottlenecks in a model.

Introduction
------------
PyTorch 1.8 includes an updated profiler API capable of
recording CPU-side operations as well as the CUDA kernel launches on the GPU side.
The profiler can visualize this information
in the TensorBoard plugin and provide an analysis of the performance bottlenecks.

In this tutorial, we will use a simple Resnet model to demonstrate how to
use the TensorBoard plugin to analyze model performance.

Setup
-----
To install ``torch`` and ``torchvision``, use the following command:

::

    pip install torch torchvision


"""

######################################################################
# Steps
# -----
#
# 1. Prepare the data and model
# 2. Use profiler to record execution events
# 3. Run the profiler
# 4. Use TensorBoard to view results and analyze performance
# 5. Improve performance with the help of profiler
#
# 1. Prepare the data and model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, import all necessary libraries:
#

import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms as T

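######################################################################
# Note that the code below profiles CUDA kernels, so it assumes a CUDA-capable
# GPU is available. As an optional sanity check (a sketch, not required by the
# tutorial), you can verify this right after the imports:

assert torch.cuda.is_available(), "This tutorial requires a CUDA-capable GPU."
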
######################################################################
# Then prepare the input data. For this tutorial, we use the CIFAR10 dataset.
# Transform it to the desired format and use DataLoader to load each batch.

transform = T.Compose(
    [T.Resize(224),
     T.ToTensor(),
     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
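
######################################################################
# As a quick optional check (a sketch, not part of the training code), you can
# pull one batch from the loader above; with the ``T.Resize(224)`` transform,
# each batch of images should have shape ``[32, 3, 224, 224]`` and the labels
# shape ``[32]``.

images, labels = next(iter(train_loader))
print(images.shape)   # expected: torch.Size([32, 3, 224, 224])
print(labels.shape)   # expected: torch.Size([32])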

######################################################################
# Next, create the Resnet model, loss function, and optimizer objects.
# To run on the GPU, move the model and loss to the GPU device.

device = torch.device("cuda:0")
model = torchvision.models.resnet18(pretrained=True).cuda(device)
criterion = torch.nn.CrossEntropyLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()


######################################################################
# Define the training step for each batch of input data.

def train(data):
    inputs, labels = data[0].to(device=device), data[1].to(device=device)
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


######################################################################
# 2. Use profiler to record execution events
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The profiler is enabled through the context manager and accepts several parameters;
# some of the most useful are:
#
# - ``schedule`` - a callable that takes a step number (int) as its single parameter
#   and returns the profiler action to perform at that step.
#   In this example with ``wait=1, warmup=1, active=5``,
#   the profiler will skip the first step/iteration,
#   start warming up on the second,
#   and record the following five iterations,
#   after which the trace becomes available and ``on_trace_ready`` (when set) is called.
#   The cycle then repeats starting with the next step until the loop exits.
#   During ``wait`` steps, the profiler is disabled.
#   During ``warmup`` steps, the profiler starts profiling but does not record any events;
#   this reduces the profiling overhead, which is high at the beginning of profiling
#   and would otherwise skew the profiling result.
#   During ``active`` steps, the profiler works and records events.
# - ``on_trace_ready`` - a callable that is called at the end of each cycle.
#   In this example we use ``torch.profiler.tensorboard_trace_handler`` to generate result files for TensorBoard.
#   After profiling, the result files are saved into the ``./log/resnet18`` directory.
#   Specify this directory as the ``logdir`` parameter to analyze the profile in TensorBoard.
# - ``record_shapes`` - whether to record the shapes of the operator inputs.

with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=5),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
        record_shapes=True
) as prof:
    for step, batch_data in enumerate(train_loader):
        if step >= 7:
            break
        train(batch_data)
        prof.step()  # Need to call this at the end of each step to notify the profiler of the step boundary.
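
######################################################################
# To build intuition for the cycle described above, note that
# ``torch.profiler.schedule`` simply returns a callable that maps a step index
# to a profiler action, so you can inspect it directly. This is an optional,
# illustrative sketch and is not needed for profiling:

sched = torch.profiler.schedule(wait=1, warmup=1, active=5)
for s in range(7):
    # Roughly: step 0 waits, step 1 warms up, steps 2-6 record,
    # and the last active step also triggers saving the trace.
    print(s, sched(s))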


######################################################################
# 3. Run the profiler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Run the above code. The profiling result will be saved under the ``./log`` directory.
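
######################################################################
# As an optional check (a sketch; the exact file names vary by host and
# timestamp), you can list the generated trace files.
# ``tensorboard_trace_handler`` typically writes ``*.pt.trace.json`` files
# into the log directory:

import glob
print(glob.glob('./log/resnet18/*.pt.trace.json'))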


######################################################################
# 4. Use TensorBoard to view results and analyze performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Install the PyTorch Profiler TensorBoard plugin.
#
# ::
#
#     pip install torch_tb_profiler
#

######################################################################
# Launch TensorBoard.
#
# ::
#
#     tensorboard --logdir=./log
#

######################################################################
# Open the TensorBoard profile URL in the Google Chrome or Microsoft Edge browser.
#
# ::
#
#     http://localhost:6006/#pytorch_profiler
#

######################################################################
# You should see the Profiler plugin page as shown below.
#
# .. image:: ../../_static/img/profiler_overview1.png
#    :scale: 25 %
#
# The overview shows a high-level summary of model performance.
#
# The "Step Time Breakdown" shows the distribution of time spent in each step over different categories of execution.
# In this example, you can see that the ``DataLoader`` overhead is significant.
#
# The "Performance Recommendation" at the bottom uses the profiling data
# to automatically highlight likely bottlenecks
# and gives you actionable optimization suggestions.
#
# You can change the view page in the left "Views" dropdown list.
#
# .. image:: ../../_static/img/profiler_views_list.png
#    :alt:
#
# The operator view displays the performance of every PyTorch operator
# that is executed either on the host or on the device.
#
# The GPU kernel view shows the time each kernel spends on the GPU.
#
# The trace view shows a timeline of the profiled operators and GPU kernels.
# You can select it to see details as below.
#
# .. image:: ../../_static/img/profiler_trace_view1.png
#    :scale: 25 %
#
# You can move the graph and zoom in/out with the help of the toolbar on the right side.
#
# In this example, we can see that the event prefixed with ``enumerate(DataLoader)`` costs a lot of time,
# and during most of this period the GPU is idle:
# this function is loading and transforming data on the host side,
# during which the GPU resources are wasted.


######################################################################
# 5. Improve performance with the help of profiler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The PyTorch ``DataLoader`` uses a single process by default.
# You can enable multi-process data loading by setting the parameter ``num_workers``.
# See `here <https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading>`_ for more details.
#
# In this example, we can set ``num_workers`` as below,
# pass a different name such as ``./log/resnet18_4workers`` to ``tensorboard_trace_handler``, and run it again.
#
# ::
#
#     train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
#
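
######################################################################
# For reference, a sketch of the re-run that combines both changes (4 worker
# processes and a separate log directory) could look like this:

train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=5),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18_4workers'),
        record_shapes=True
) as prof:
    for step, batch_data in enumerate(train_loader):
        if step >= 7:
            break
        train(batch_data)
        prof.step()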

######################################################################
# Then choose the newly profiled run in the left "Runs" dropdown list.
#
# .. image:: ../../_static/img/profiler_overview2.png
#    :scale: 25 %
#
# From the above view, we can see that the step time is reduced,
# and that the reduction in ``DataLoader`` time is the main contributor.
#
# .. image:: ../../_static/img/profiler_trace_view2.png
#    :scale: 25 %
#
# From the above view, we can see that the runtime of ``enumerate(DataLoader)`` is reduced
# and the GPU utilization is increased.

######################################################################
# Learn More
# ----------
#
# Take a look at the following documents to continue your learning:
#
# - `PyTorch TensorBoard Profiler GitHub <https://github.com/pytorch/kineto/tree/master/tb_plugin>`_
# - `torch.profiler API <https://pytorch.org/docs/master/profiler.html>`_
