Commit ff05ebc

[mobile] Mobile Perf Recipe
1 parent f7d7360 commit ff05ebc

File tree

3 files changed: +209 -0 lines changed

index.rst

Lines changed: 7 additions & 0 deletions
@@ -288,6 +288,13 @@ Welcome to PyTorch Tutorials
    :link: advanced/static_quantization_tutorial.html
    :tags: Image/Video,Quantization,Model-Optimization

+.. customcarditem::
+   :header: PyTorch Mobile Performance Recipes
+   :card_description: List of recipes for performance optimizations for using PyTorch on Mobile.
+   :image: _static/img/thumbnails/cropped/experimental-Dynamic-Quantization-on-an-LSTM-Word-Language-Model.png
+   :link: recipes/recipes/mobile_perf.html
+   :tags: Mobile,Model-Optimization
+
 .. Parallel-and-Distributed-Training

 .. customcarditem::

recipes_source/recipes/README.txt

Lines changed: 4 additions & 0 deletions
@@ -52,3 +52,7 @@ PyTorch Recipes
 13. zeroing_out_gradients.py
        Zeroing out gradients
        https://pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html
+
+14. mobile_perf.py
+       PyTorch Mobile Performance Recipes
+       https://pytorch.org/tutorials/recipes/recipes/mobile_perf.html

recipes_source/recipes/mobile_perf.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
######################################################################
# PyTorch Mobile Performance Recipes
# ==================================
# Introduction
# ------------
# Performance (aka latency) is crucial to most, if not all,
# applications and use cases of ML model inference on mobile devices.
#
# Today, PyTorch executes models on the CPU backend, pending availability
# of other hardware backends such as GPU, DSP, and NPU.
#
######################################################################
# Model preparation
# -----------------
#
# The following recipes can be applied offline, while preparing the model,
# to produce an optimized model that will likely have a shorter execution time
# (higher performance, lower latency) on the mobile device.
######################################################################
# 1. Use torch.utils.mobile_optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The torch.utils.mobile_optimizer package performs several optimizations on the model,
# which help conv2d and linear operations.
# It pre-packs the model weights in an optimized format and fuses those ops with relu
# when relu is the next operation.
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

traced_model = torch.jit.load("input_model_path")
optimized_model = optimize_for_mobile(traced_model)
torch.jit.save(optimized_model, "output_model_path")
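
######################################################################
# If you do not yet have a traced model on disk, a minimal sketch of
# producing one first (the torchvision model and output file name here
# are illustrative assumptions, not part of the recipe):

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Trace a float model with an example input, then optimize the traced
# module for mobile and serialize it.
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
traced_model = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
optimized_model = optimize_for_mobile(traced_model)
torch.jit.save(optimized_model, "mobilenetV2-mobile-optimized.pt")
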
######################################################################
# 2. Fuse operators using ``torch.quantization.fuse_modules``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Do not be confused that fuse_modules is in the quantization package:
# it works for any ``torch.nn.Module``, not only quantized ones.
# ``torch.quantization.fuse_modules`` fuses a list of modules into a single module.
# It fuses only the following sequences of modules:
#
# - Convolution, Batch normalization
# - Convolution, Batch normalization, Relu
# - Convolution, Relu
# - Linear, Relu
#
# This script shows how to fuse Batch normalization in torchvision mobileNetV2.
#
import torch
import torchvision

# MobileNetV2 building blocks (from torchvision):
# ConvBNReLU: nn.Sequential[nn.Conv2d, nn.BatchNorm2d, nn.ReLU]
# InvertedResidual: nn.Sequential[ConvBNReLU, ConvBNReLU, nn.Conv2d, nn.BatchNorm2d]
def fuse_model(model):
    for m in model.modules():
        name = m._get_name()
        if name == 'ConvBNReLU':
            # Fuse Conv2d ('0') with the BatchNorm2d ('1') that follows it.
            torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
        if name == 'InvertedResidual':
            # Fuse each Conv2d in the block with the BatchNorm2d that follows it.
            for idx in range(len(m.conv)):
                if type(m.conv[idx]) == torch.nn.Conv2d:
                    torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

m = torchvision.models.mobilenet_v2(pretrained=True)
m.eval()
fuse_model(m)
torch.jit.trace(m, torch.rand(1, 3, 224, 224)).save("mobilenetV2-bnfused.pt")
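
######################################################################
# As a quick sanity check (an optional sketch, not part of the original
# script): ``fuse_modules`` replaces each absorbed module with
# ``torch.nn.Identity``, so after fusion no ``BatchNorm2d`` should remain.

num_identity = sum(isinstance(mod, torch.nn.Identity) for mod in m.modules())
num_bn = sum(isinstance(mod, torch.nn.BatchNorm2d) for mod in m.modules())
print('Identity modules:', num_identity, 'remaining BatchNorm2d:', num_bn)
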
######################################################################
#
# Quantization
# ------------
#
######################################################################
# 3. Try a quantized version of your model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# You can find more about PyTorch quantization in
# `the dedicated tutorial <https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_.
#
# Quantizing the model not only moves computation to int8,
# but also reduces the size of your model on disk.
# That size reduction helps to reduce disk read operations during the first
# load of the model and decreases the amount of RAM used.
# Both of those resources can be crucial for the performance of mobile applications.
#

import torch
import torchvision

# List the quantized engines available in this build.
supported_qengines = torch.backends.quantized.supported_engines
print(supported_qengines)

model = torchvision.models.quantization.__dict__['mobilenet_v2'](pretrained=True, quantize=True)
# qnnpack is the quantized backend used on mobile CPUs.
torch.backends.quantized.engine = 'qnnpack'
model.eval()
script_model = torch.jit.script(model)
x = torch.rand(1, 3, 224, 224)
y = script_model(x)
torch.jit.save(script_model, 'mobilenetV2_quantized.pt')
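
######################################################################
# To see the size reduction in practice, a small sketch: it assumes the
# quantized file saved above and, for comparison, a float model file such
# as the hypothetical ``mobilenetV2-mobile-optimized.pt`` from the tracing
# sketch earlier (any serialized float model of the same architecture works).

import os

for path in ['mobilenetV2-mobile-optimized.pt', 'mobilenetV2_quantized.pt']:
    if os.path.exists(path):
        # Report the on-disk size of each serialized model in megabytes.
        print(path, '%.1f MB' % (os.path.getsize(path) / (1024 * 1024)))
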
######################################################################
# Android API Recommendations
# ---------------------------
#
# 4. Reusing tensors for forward
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Memory is a critical resource for Android performance, especially on old devices.
# Tensors can require a significant amount of memory.
# For example, a standard computer vision tensor contains 1*3*224*224 elements;
# assuming the data type is float (4 bytes per element), it needs 588 KB of memory.
FloatBuffer buffer = Tensor.allocateFloatBuffer(1*3*224*224);
Tensor tensor = Tensor.fromBlob(buffer, new long[]{1, 3, 224, 224});
######################################################################
# .. Note::
#    Here we allocate native memory as ``java.nio.FloatBuffer`` and create an
#    ``org.pytorch.Tensor`` whose storage points to the memory of the allocated buffer.
#
# In most use cases we do not run model forward only once;
# we repeat it with some frequency, or as fast as possible.
#
# Allocating new memory for every module forward is suboptimal.
# Instead, we can reuse the memory allocated in the previous step,
# fill it with new data, and run module forward again on the same tensor object.
#
# You can check how this looks in code in the
# `pytorch android application example <https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java#L174>`_.
protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {
  if (mModule == null) {
    mModule = Module.load(moduleFileAbsoluteFilePath);
    mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224);
    mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, 224, 224});
  }

  TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
      image.getImage(), rotationDegrees,
      224, 224,
      TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
      TensorImageUtils.TORCHVISION_NORM_STD_RGB,
      mInputTensorBuffer, 0);

  Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
}
######################################################################
# .. Note::
#    Member fields ``mModule``, ``mInputTensorBuffer`` and ``mInputTensor`` are initialized only once,
#    and the buffer is refilled using ``org.pytorch.torchvision.TensorImageUtils.imageYUV420CenterCropToFloatBuffer``.
#
######################################################################
# Benchmarking
# ------------
#
# The best way to benchmark (that is, to check whether the optimizations helped your use case)
# is to measure the particular use case you want to optimize,
# as performance behavior can vary in different environments.
#
# The PyTorch distribution provides a way to benchmark a bare binary that runs the model forward;
# this approach can give more stable measurements than testing inside an application.
#
######################################################################
# Android
# -------
#
# For this, you first need to build the benchmark binary:
#
<from-your-root-pytorch-dir>
rm -rf build_android
BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON
######################################################################
# You should now have an arm64 binary at ``build_android/bin/speed_benchmark_torch``.
# This binary takes ``--model=<path-to-model>``, ``--input_dims="1,3,224,224"`` as dimension
# information for the input, and ``--input_type="float"`` as the type of the input, as arguments.
#
# Once you have your Android device connected,
# push the speed_benchmark_torch binary and your model to the phone:
adb push <speedbenchmark-torch> /data/local/tmp
adb push <path-to-scripted-model> /data/local/tmp
######################################################################
# Now we are ready to benchmark the model:
#
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt --input_dims=1,3,224,224 --input_type=float"
----- output -----
Starting benchmark.
Running warmup runs.
Main runs.
Main run finished. Microseconds per iter: 121318. Iters per second: 8.24281
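
######################################################################
# As a rough host-side sanity check before going to the device (a sketch;
# the on-device numbers from ``speed_benchmark_torch`` are the real signal),
# you can time the scripted model from the quantization step in Python:

import time
import torch

# Assumes ``script_model`` from the quantization step above is in scope.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    for _ in range(5):          # warmup iterations
        script_model(x)
    n = 20
    start = time.time()
    for _ in range(n):          # timed iterations
        script_model(x)
    print('ms per iter: %.1f' % ((time.time() - start) / n * 1000.0))
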
