(Prototype) Convert MobileNetV2 to NNAPI
========================================

Introduction
------------

This tutorial shows how to prepare a computer vision model to use
`Android's Neural Networks API (NNAPI) <https://developer.android.com/ndk/guides/neuralnetworks>`_.
NNAPI provides access to powerful and efficient computational cores
on many modern Android devices.

PyTorch's NNAPI integration is currently in the "prototype" phase and only
supports a limited range of operators, but we expect to solidify the
integration and expand operator support over time.


Model Preparation
-----------------

First, we must prepare our model to execute with NNAPI.
This step runs on your training server or laptop.
The key conversion function to call is
``torch.backends._nnapi.prepare.convert_model_to_nnapi``,
but some extra steps are required to ensure that
the model is properly structured.
Most notably, quantizing the model is required
to run the model on certain hardware backends.

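For orientation, here is a condensed sketch of the float ("none") path,
extracted from the full script below (it omits quantization, CPU-model
optimization, input bundling, and saving):

.. code:: python

    import torch
    import torch.backends._nnapi.prepare
    import torchvision.models.quantization.mobilenet

    # Load the quantizable MobileNetV2 in float mode, fuse BatchNorm,
    # and drop Dropout for this inference-only use case.
    model = torchvision.models.quantization.mobilenet.mobilenet_v2(
        pretrained=True, quantize=False)
    model.eval()
    model.fuse_model()
    model.classifier[0] = torch.nn.Identity()

    # NNAPI backends prefer NHWC, so mark the example input accordingly.
    input_tensor = torch.zeros(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
    input_tensor.nnapi_nhwc = True

    # Trace the model, then convert the trace for NNAPI.
    with torch.no_grad():
        traced = torch.jit.trace(model, input_tensor)
    nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)
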
You can copy/paste this entire Python script and run it,
or make your own modifications.
By default, it will save the models to ``~/mobilenetv2-nnapi/``.

.. code:: python

    #!/usr/bin/env python
    import os
    import torch
    import torch.utils.bundled_inputs
    import torch.utils.mobile_optimizer
    import torch.backends._nnapi.prepare
    import torchvision.models.quantization.mobilenet
    from pathlib import Path


    # This script supports 3 modes of quantization:
    # - "none": Fully floating-point model.
    # - "core": Quantize the core of the model, but wrap it in a
    #   quantizer/dequantizer pair, so the interface uses floating point.
    # - "full": Quantize the model, and use quantized tensors
    #   for input and output.
    #
    # "none" maintains maximum accuracy.
    # "core" sacrifices some accuracy for performance,
    # but maintains the same interface.
    # "full" maximizes performance (with the same accuracy as "core"),
    # but requires the application to use quantized tensors.
    #
    # There is a fourth option, not supported by this script,
    # where we include the quant/dequant steps as NNAPI operators.
    def make_mobilenetv2_nnapi(output_dir_path, quantize_mode):
        quantize_core, quantize_iface = {
            "none": (False, False),
            "core": (True, False),
            "full": (True, True),
        }[quantize_mode]

        model = torchvision.models.quantization.mobilenet.mobilenet_v2(pretrained=True, quantize=quantize_core)
        model.eval()

        # Fuse BatchNorm operators in the floating point model.
        # (Quantized models already have this done.)
        # Remove dropout for this inference-only use case.
        if not quantize_core:
            model.fuse_model()
        assert type(model.classifier[0]) == torch.nn.Dropout
        model.classifier[0] = torch.nn.Identity()

        input_float = torch.zeros(1, 3, 224, 224)
        input_tensor = input_float

        # If we're doing a quantized model, we need to trace only the quantized core.
        # So capture the quantizer and dequantizer, use them to prepare the input,
        # and replace them with identity modules so we can trace without them.
        if quantize_core:
            quantizer = model.quant
            dequantizer = model.dequant
            model.quant = torch.nn.Identity()
            model.dequant = torch.nn.Identity()
            input_tensor = quantizer(input_float)

        # Many NNAPI backends prefer NHWC tensors, so convert our input to channels_last,
        # and set the "nnapi_nhwc" attribute for the converter.
        input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
        input_tensor.nnapi_nhwc = True

        # Trace the model.  NNAPI conversion only works with TorchScript models,
        # and traced models are more likely to convert successfully than scripted.
        with torch.no_grad():
            traced = torch.jit.trace(model, input_tensor)
        nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

        # If we're not using a quantized interface, wrap a quant/dequant around the core.
        if quantize_core and not quantize_iface:
            nnapi_model = torch.nn.Sequential(quantizer, nnapi_model, dequantizer)
            model.quant = quantizer
            model.dequant = dequantizer
            # Switch back to float input for benchmarking.
            input_tensor = input_float.contiguous(memory_format=torch.channels_last)

        # Optimize the CPU model to make CPU-vs-NNAPI benchmarks fair.
        model = torch.utils.mobile_optimizer.optimize_for_mobile(torch.jit.script(model))

        # Bundle sample inputs with the models for easier benchmarking.
        # This step is optional.
        class BundleWrapper(torch.nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod
            def forward(self, arg):
                return self.mod(arg)
        nnapi_model = torch.jit.script(BundleWrapper(nnapi_model))
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            nnapi_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])

        # Save both models.
        model.save(output_dir_path / ("mobilenetv2-quant_{}-cpu.pt".format(quantize_mode)))
        nnapi_model.save(output_dir_path / ("mobilenetv2-quant_{}-nnapi.pt".format(quantize_mode)))


    if __name__ == "__main__":
        output_dir_path = Path(os.environ["HOME"]) / "mobilenetv2-nnapi"
        # Make sure the output directory exists before saving models into it.
        output_dir_path.mkdir(exist_ok=True)
        for quantize_mode in ["none", "core", "full"]:
            make_mobilenetv2_nnapi(output_dir_path, quantize_mode)


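Before benchmarking on a device, you can sanity-check a saved CPU model on
the host using its bundled input (the NNAPI models only execute on an
Android device). This is a quick check, assuming the default output paths
from the script above:

.. code:: python

    import torch
    from pathlib import Path

    # Load the float CPU model and run it on its bundled sample input.
    model = torch.jit.load(str(Path.home() / "mobilenetv2-nnapi" / "mobilenetv2-quant_none-cpu.pt"))
    (example,) = model.get_all_bundled_inputs()[0]
    with torch.no_grad():
        output = model(example)
    print(output.shape)  # expected: torch.Size([1, 1000])
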
Running Benchmarks
------------------

Now that the models are ready, we can benchmark them on our Android devices.
See `our performance recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html#android-benchmarking-setup>`_ for details.
The best-performing models are likely to be the "fully-quantized" models:
``mobilenetv2-quant_full-cpu.pt`` and ``mobilenetv2-quant_full-nnapi.pt``.

Because these models have bundled inputs, we can run the benchmark as follows:

.. code:: shell

    ./speed_benchmark_torch --pthreadpool_size=1 --model=mobilenetv2-quant_full-nnapi.pt --use_bundled_input=0 --warmup=5 --iter=200

Increasing the thread pool size can reduce latency,
at the cost of increased CPU usage.
Omitting that argument will use one thread per big core.
The CPU models can get improved performance (at the cost of memory usage)
by passing ``--use_caching_allocator=true``.
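
For example, here is the matching CPU benchmark with the caching allocator
enabled (assuming your ``speed_benchmark_torch`` binary was built with
caching-allocator support):

.. code:: shell

    ./speed_benchmark_torch --pthreadpool_size=1 --model=mobilenetv2-quant_full-cpu.pt --use_bundled_input=0 --warmup=5 --iter=200 --use_caching_allocator=true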


Integration
-----------

The converted models are ordinary TorchScript models.
You can use them in your app just like any other PyTorch model.
See `https://pytorch.org/mobile/android/ <https://pytorch.org/mobile/android/>`_
for an introduction to using PyTorch on Android.


Learn More
----------

- Learn more about optimization in our
  `Mobile Performance Recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_
- `MobileNetV2 <https://pytorch.org/hub/pytorch_vision_mobilenet_v2/>`_ from torchvision
- Information about `NNAPI <https://developer.android.com/ndk/guides/neuralnetworks>`_