|
| 1 | +""" |
| 2 | +PyTorch Profiler With TensorBoard |
| 3 | +==================================== |
| 4 | +This tutorial demonstrates how to use TensorBoard plugin with PyTorch Profiler |
| 5 | +to detect performance bottlenecks of the model. |
| 6 | +
|
| 7 | +Introduction |
| 8 | +------------ |
| 9 | +PyTorch 1.8 includes an updated profiler API capable of |
| 10 | +recording the CPU side operations as well as the CUDA kernel launches on the GPU side. |
| 11 | +The profiler can visualize this information |
| 12 | +in TensorBoard Plugin and provide analysis of the performance bottlenecks. |
| 13 | +
|
| 14 | +In this tutorial, we will use a simple Resnet model to demonstrate how to |
| 15 | +use TensorBoard plugin to analyze model performance. |
| 16 | +
|
| 17 | +Setup |
| 18 | +----- |
| 19 | +To install ``torch`` and ``torchvision`` use the following command: |
| 20 | +
|
| 21 | +:: |
| 22 | +
|
| 23 | + pip install torch torchvision |
| 24 | +
|
| 25 | +
|
| 26 | +""" |
| 27 | + |
| 28 | + |
| 29 | +###################################################################### |
| 30 | +# Steps |
| 31 | +# ----- |
| 32 | +# |
| 33 | +# 1. Prepare the data and model |
| 34 | +# 2. Use profiler to record execution events |
| 35 | +# 3. Run the profiler |
| 36 | +# 4. Use TensorBoard to view results and analyze performance |
| 37 | +# 5. Improve performance with the help of profiler |
| 38 | +# |
| 39 | +# 1. Prepare the data and model |
| 40 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 41 | +# |
| 42 | +# First, import all necessary libraries: |
| 43 | +# |
| 44 | + |
| 45 | +import torch |
| 46 | +import torch.nn |
| 47 | +import torch.optim |
| 48 | +import torch.profiler |
| 49 | +import torch.utils.data |
| 50 | +import torchvision.datasets |
| 51 | +import torchvision.models |
| 52 | +import torchvision.transforms as T |
| 53 | + |
| 54 | +###################################################################### |
| 55 | +# Then prepare the input data. For this tutorial, we use the CIFAR10 dataset. |
| 56 | +# Transform it to the desired format and use DataLoader to load each batch. |
| 57 | + |
| 58 | +transform = T.Compose( |
| 59 | + [T.Resize(224), |
| 60 | + T.ToTensor(), |
| 61 | + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
| 62 | +train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) |
| 63 | +train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) |
| 64 | + |
| 65 | +###################################################################### |
| 66 | +# Next, create Resnet model, loss function, and optimizer objects. |
| 67 | +# To run on GPU, move model and loss to GPU device. |
| 68 | + |
| 69 | +device = torch.device("cuda:0") |
| 70 | +model = torchvision.models.resnet18(pretrained=True).cuda(device) |
| 71 | +criterion = torch.nn.CrossEntropyLoss().cuda(device) |
| 72 | +optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) |
| 73 | +model.train() |
| 74 | + |
| 75 | + |
| 76 | +###################################################################### |
| 77 | +# Define the training step for each batch of input data. |
| 78 | + |
| 79 | +def train(data): |
| 80 | + inputs, labels = data[0].to(device=device), data[1].to(device=device) |
| 81 | + outputs = model(inputs) |
| 82 | + loss = criterion(outputs, labels) |
| 83 | + optimizer.zero_grad() |
| 84 | + loss.backward() |
| 85 | + optimizer.step() |
| 86 | + |
| 87 | + |
| 88 | +###################################################################### |
| 89 | +# 2. Use profiler to record execution events |
| 90 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 91 | +# |
| 92 | +# The profiler is enabled through the context manager and accepts several parameters, |
| 93 | +# some of the most useful are: |
| 94 | +# |
| 95 | +# - ``schedule`` - callable that takes step (int) as a single parameter |
| 96 | +# and returns the profiler action to perform at each step; |
| 97 | +# In this example with ``wait=1, warmup=1, active=5``, |
| 98 | +# profiler will skip the first step/iteration, |
| 99 | +# start warming up on the second, |
| 100 | +# record the following five iterations, |
| 101 | +# after which the trace will become available and on_trace_ready (when set) is called; |
| 102 | +# The cycle repeats starting with the next step until the loop exits. |
| 103 | +# During ``wait`` steps, the profiler does not work. |
| 104 | +# During ``warmup`` steps, the profiler starts profiling as warmup but does not record any events. |
| 105 | +# This is for reducing the profiling overhead. |
| 106 | +# The overhead at the beginning of profiling is high and easy to bring skew to the profiling result. |
| 107 | +# During ``active`` steps, the profiler works and records events. |
| 108 | +# - ``on_trace_ready`` - callable that is called at the end of each cycle; |
| 109 | +# In this example we use ``torch.profiler.tensorboard_trace_handler`` to generate result files for TensorBoard. |
| 110 | +# After profiling, result files will be saved into the ``./log/resnet18`` directory. |
| 111 | +# Specify this directory as a ``logdir`` parameter to analyze profile in TensorBoard. |
| 112 | +# - ``record_shapes`` - whether to record shapes of the operator inputs. |
| 113 | + |
| 114 | +with torch.profiler.profile( |
| 115 | + schedule=torch.profiler.schedule(wait=1, warmup=1, active=5), |
| 116 | + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'), |
| 117 | + record_shapes=True |
| 118 | +) as prof: |
| 119 | + for step, batch_data in enumerate(train_loader): |
| 120 | + if step >= 7: |
| 121 | + break |
| 122 | + train(batch_data) |
| 123 | + prof.step() # Need call this at the end of each step to notify profiler of steps' boundary. |
| 124 | + |
| 125 | + |
| 126 | +###################################################################### |
| 127 | +# 3. Run the profiler |
| 128 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 129 | +# |
| 130 | +# Run the above code. The profiling result will be saved under ``./log`` directory. |
| 131 | + |
| 132 | + |
| 133 | +###################################################################### |
| 134 | +# 4. Use TensorBoard to view results and analyze performance |
| 135 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 136 | +# |
| 137 | +# Install PyTorch Profiler TensorBoard Plugin. |
| 138 | +# |
| 139 | +# :: |
| 140 | +# |
| 141 | +# pip install torch_tb_profiler |
| 142 | +# |
| 143 | + |
| 144 | +###################################################################### |
| 145 | +# Launch the TensorBoard. |
| 146 | +# |
| 147 | +# :: |
| 148 | +# |
| 149 | +# tensorboard --logdir=./log |
| 150 | +# |
| 151 | + |
| 152 | +###################################################################### |
| 153 | +# Open the TensorBoard profile URL in Google Chrome browser or Microsoft Edge browser. |
| 154 | +# |
| 155 | +# :: |
| 156 | +# |
| 157 | +# http://localhost:6006/#pytorch_profiler |
| 158 | +# |
| 159 | + |
| 160 | +###################################################################### |
| 161 | +# You could see Profiler plugin page as shown below. |
| 162 | +# |
| 163 | +# .. image:: ../../_static/img/profiler_overview1.png |
| 164 | +# :scale: 25 % |
| 165 | +# |
| 166 | +# The overview shows a high-level summary of model performance. |
| 167 | +# |
| 168 | +# The "Step Time Breakdown" shows distribution of time spent in each step over different categories of execution. |
| 169 | +# In this example, you can see the ``DataLoader`` overhead is significant. |
| 170 | +# |
| 171 | +# The bottom "Performance Recommendation" uses the profiling data |
| 172 | +# to automatically highlight likely bottlenecks, |
| 173 | +# and gives you actionable optimization suggestions. |
| 174 | +# |
| 175 | +# You can change the view page in left "Views" dropdown list. |
| 176 | +# |
| 177 | +# .. image:: ../../_static/img/profiler_views_list.png |
| 178 | +# :alt: |
| 179 | +# |
| 180 | +# The operator view displays the performance of every PyTorch operator |
| 181 | +# that is executed either on the host or device. |
| 182 | +# |
| 183 | +# The GPU kernel view shows all kernels’ time spent on GPU. |
| 184 | +# |
| 185 | +# The trace view shows timeline of profiled operators and GPU kernels. |
| 186 | +# You can select it to see details as below. |
| 187 | +# |
| 188 | +# .. image:: ../../_static/img/profiler_trace_view1.png |
| 189 | +# :scale: 25 % |
| 190 | +# |
| 191 | +# You can move the graph and zoom in/out with the help of right side toolbar. |
| 192 | +# |
| 193 | +# In this example, we can see the event prefixed with ``enumerate(DataLoader)`` costs a lot of time. |
| 194 | +# And during most of this period, the GPU is idle. |
| 195 | +# Because this function is loading data and transforming data on host side, |
| 196 | +# during which the GPU resource is wasted. |
| 197 | + |
| 198 | + |
| 199 | +###################################################################### |
| 200 | +# 5. Improve performance with the help of profiler |
| 201 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 202 | +# |
| 203 | +# The PyTorch DataLoader uses single process by default. |
| 204 | +# User could enable multi-process data loading by setting the parameter ``num_workers``. |
| 205 | +# `Here <https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading>`_ is more details. |
| 206 | +# |
| 207 | +# In this example, we can set ``num_workers`` as below, |
| 208 | +# pass a different name such as ``./log/resnet18_4workers`` to tensorboard_trace_handler, and run it again. |
| 209 | +# |
| 210 | +# :: |
| 211 | +# |
| 212 | +# train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4) |
| 213 | +# |
| 214 | + |
| 215 | +###################################################################### |
| 216 | +# Then let’s choose the just profiled run in left "Runs" dropdown list. |
| 217 | +# |
| 218 | +# .. image:: ../../_static/img/profiler_overview2.png |
| 219 | +# :scale: 25 % |
| 220 | +# |
| 221 | +# From the above view, we can find the step time is reduced, |
| 222 | +# and the time reduction of ``DataLoader`` mainly contributes. |
| 223 | +# |
| 224 | +# .. image:: ../../_static/img/profiler_trace_view2.png |
| 225 | +# :scale: 25 % |
| 226 | +# |
| 227 | +# From the above view, we can see that the runtime of ``enumerate(DataLoader)`` is reduced, |
| 228 | +# and the GPU utilization is increased. |
| 229 | + |
| 230 | +###################################################################### |
| 231 | +# Learn More |
| 232 | +# ---------- |
| 233 | +# |
| 234 | +# Take a look at the following documents to continue your learning: |
| 235 | +# |
| 236 | +# - `Pytorch TensorBoard Profiler github <https://github.com/pytorch/kineto/tree/master/tb_plugin>`_ |
| 237 | +# - `torch.profiler API <https://pytorch.org/docs/master/profiler.html>`_ |
0 commit comments