(Prototype) Convert MobileNetV2 to NNAPI
========================================

Introduction
------------

This tutorial shows how to prepare a computer vision model to use
`Android's Neural Networks API (NNAPI) <https://developer.android.com/ndk/guides/neuralnetworks>`_.
NNAPI provides access to powerful and efficient computational cores
on many modern Android devices.

PyTorch's NNAPI support is currently in the "prototype" phase and covers
only a limited range of operators, but we expect to solidify the integration
and expand operator coverage over time.

Environment
-----------

Install PyTorch and torchvision.
This tutorial is currently incompatible with the latest trunk,
so we recommend running
``pip install --upgrade --pre --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html torch==1.8.0.dev20201106+cpu torchvision==0.9.0.dev20201107+cpu``
until this incompatibility is corrected.
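
Since the command above pins specific nightly builds, it can help to confirm which
versions are actually active in your environment before converting. A quick check
like the following (just printing the installed versions) is enough:

.. code:: python

    import torch
    import torchvision

    # Confirm the pinned nightly builds are the ones in use.
    print("torch:", torch.__version__)
    print("torchvision:", torchvision.__version__)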

Model Preparation
-----------------

First, we must prepare our model to execute with NNAPI.
This step runs on your training server or laptop.
The key conversion function to call is
``torch.backends._nnapi.prepare.convert_model_to_nnapi``,
but some extra steps are required to ensure that
the model is properly structured.
Most notably, quantizing the model is required
in order to run the model on certain accelerators.

You can copy/paste this entire Python script and run it,
or make your own modifications.
By default, it will save the models to ``~/mobilenetv2-nnapi/``.
Please create that directory first.

.. code:: python

    #!/usr/bin/env python
    import sys
    import os
    import torch
    import torch.utils.bundled_inputs
    import torch.utils.mobile_optimizer
    import torch.backends._nnapi.prepare
    import torchvision.models.quantization.mobilenet
    from pathlib import Path


    # This script supports 3 modes of quantization:
    # - "none": Fully floating-point model.
    # - "core": Quantize the core of the model, but wrap it in a
    #   quantizer/dequantizer pair, so the interface uses floating point.
    # - "full": Quantize the model, and use quantized tensors
    #   for input and output.
    #
    # "none" maintains maximum accuracy.
    # "core" sacrifices some accuracy for performance,
    # but maintains the same interface.
    # "full" maximizes performance (with the same accuracy as "core"),
    # but requires the application to use quantized tensors.
    #
    # There is a fourth option, not supported by this script,
    # where we include the quant/dequant steps as NNAPI operators.
    def make_mobilenetv2_nnapi(output_dir_path, quantize_mode):
        quantize_core, quantize_iface = {
            "none": (False, False),
            "core": (True, False),
            "full": (True, True),
        }[quantize_mode]

        model = torchvision.models.quantization.mobilenet.mobilenet_v2(pretrained=True, quantize=quantize_core)
        model.eval()

        # Fuse BatchNorm operators in the floating point model.
        # (Quantized models already have this done.)
        # Remove dropout for this inference-only use case.
        if not quantize_core:
            model.fuse_model()
        assert type(model.classifier[0]) == torch.nn.Dropout
        model.classifier[0] = torch.nn.Identity()

        input_float = torch.zeros(1, 3, 224, 224)
        input_tensor = input_float

        # If we're doing a quantized model, we need to trace only the quantized core.
        # So capture the quantizer and dequantizer, use them to prepare the input,
        # and replace them with identity modules so we can trace without them.
        if quantize_core:
            quantizer = model.quant
            dequantizer = model.dequant
            model.quant = torch.nn.Identity()
            model.dequant = torch.nn.Identity()
            input_tensor = quantizer(input_float)

        # Many NNAPI backends prefer NHWC tensors, so convert our input to channels_last,
        # and set the "nnapi_nhwc" attribute for the converter.
        input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
        input_tensor.nnapi_nhwc = True

        # Trace the model.  NNAPI conversion only works with TorchScript models,
        # and traced models are more likely to convert successfully than scripted.
        with torch.no_grad():
            traced = torch.jit.trace(model, input_tensor)
        nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

        # If we're not using a quantized interface, wrap a quant/dequant around the core.
        if quantize_core and not quantize_iface:
            nnapi_model = torch.nn.Sequential(quantizer, nnapi_model, dequantizer)
            model.quant = quantizer
            model.dequant = dequantizer
            # Switch back to float input for benchmarking.
            input_tensor = input_float.contiguous(memory_format=torch.channels_last)

        # Optimize the CPU model to make CPU-vs-NNAPI benchmarks fair.
        model = torch.utils.mobile_optimizer.optimize_for_mobile(torch.jit.script(model))

        # Bundle sample inputs with the models for easier benchmarking.
        # This step is optional.
        class BundleWrapper(torch.nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod
            def forward(self, arg):
                return self.mod(arg)
        nnapi_model = torch.jit.script(BundleWrapper(nnapi_model))
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            nnapi_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])

        # Save both models.
        model.save(output_dir_path / ("mobilenetv2-quant_{}-cpu.pt".format(quantize_mode)))
        nnapi_model.save(output_dir_path / ("mobilenetv2-quant_{}-nnapi.pt".format(quantize_mode)))


    if __name__ == "__main__":
        for quantize_mode in ["none", "core", "full"]:
            make_mobilenetv2_nnapi(Path(os.environ["HOME"]) / "mobilenetv2-nnapi", quantize_mode)
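
If you want a quick sanity check before moving the files to a device, you can load
one of the saved CPU models on the host and run it against its bundled sample input.
This is an optional sketch, not part of the conversion flow: it assumes the default
``~/mobilenetv2-nnapi/`` output directory, relies on the ``get_all_bundled_inputs``
accessor that ``augment_model_with_bundled_inputs`` attaches to the module, and uses
the float ("none") model, which is the simplest to run on a development machine.
The NNAPI models themselves require a device with an NNAPI driver.

.. code:: python

    import os
    from pathlib import Path

    import torch

    # Optional host-side sanity check: load a saved CPU model and run its bundled input.
    model_path = Path(os.environ["HOME"]) / "mobilenetv2-nnapi" / "mobilenetv2-quant_none-cpu.pt"
    model = torch.jit.load(str(model_path))

    # augment_model_with_bundled_inputs attaches this accessor to the module.
    args = model.get_all_bundled_inputs()[0]
    with torch.no_grad():
        output = model(*args)
    print(output.shape)  # Expect a [1, 1000] tensor of class scores.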

Running Benchmarks
------------------

Now that the models are ready, we can benchmark them on our Android devices.
See `our performance recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html#android-benchmarking-setup>`_ for details.
The best-performing models are likely to be the "fully-quantized" models:
``mobilenetv2-quant_full-cpu.pt`` and ``mobilenetv2-quant_full-nnapi.pt``.

Because these models have bundled inputs, we can run the benchmark as follows:

.. code:: shell

    ./speed_benchmark_torch --pthreadpool_size=1 --model=mobilenetv2-quant_full-nnapi.pt --use_bundled_input=0 --warmup=5 --iter=200

Increasing the thread pool size can reduce latency,
at the cost of increased CPU usage.
Omitting that argument will use one thread per big core.
The CPU models can get improved performance (at the cost of memory usage)
by passing ``--use_caching_allocator=true``.

Integration
-----------

The converted models are ordinary TorchScript models.
You can use them in your app just like any other PyTorch model.
See `https://pytorch.org/mobile/android/ <https://pytorch.org/mobile/android/>`_
for an introduction to using PyTorch on Android.

Learn More
----------

- Learn more about optimization in our
  `Mobile Performance Recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_
- `MobileNetV2 <https://pytorch.org/hub/pytorch_vision_mobilenet_v2/>`_ from torchvision
- Information about `NNAPI <https://developer.android.com/ndk/guides/neuralnetworks>`_
