Skip to content

1.7 release #1206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions prototype_source/ios_gpu_workflow.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
(Prototype) Use iOS GPU in PyTorch
==================================

**Author**: `Tao Xu <https://github.com/xta0>`_

Introduction
------------

This tutorial introduces the steps to run your models on iOS GPU. We'll be using the mobilenetv2 model as an example. Since the mobile GPU features are currently in the prototype stage, you'll need to build a custom pytorch binary from source. For the time being, only a limited number of operators are supported, and certain client side APIs are subject to change in the future versions.

Model Preparation
-------------------

Since GPUs consume weights in a different order, the first step we need to do is to convert our TorchScript model to a GPU compatible model. This step is also known as "prepacking". To do that, we'll build a custom pytorch binary from source that includes the Metal backend. Go ahead checkout the pytorch source code from github and run the command below

.. code:: shell

cd PYTORCH_ROOT
USE_PYTORCH_METAL=ON python setup.py install --cmake

The command above will build a custom pytorch binary from master. The ``install`` argument simply tells ``setup.py`` to override the existing PyTorch on your desktop. Once the build finished, open another terminal to check the PyTorch version to see if the installation was successful. As the time of writing of this recipe, the version is ``1.8.0a0+41237a4``. You might be seeing different numbers depending on when you check out the code from master, but it should be greater than 1.7.0.

.. code:: python

import torch
torch.__version__ #1.8.0a0+41237a4


The next step is going to be converting the mobilenetv2 torchscript model to a Metal compatible model. We'll be leveraging the ``optimize_for_mobile`` API from the ``torch.utils`` module. As shown below

.. code:: python

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
scripted_model = torch.jit.script(model)
optimized_model = optimize_for_mobile(scripted_model, backend='metal')
print(torch.jit.export_opnames(optimized_model))
torch.jit.save(optimized_model, './mobilenetv2_metal.pt')

Note that the ``torch.jit.export_opnames(optimized_model)`` is going to dump all the optimized operators from the ``optimized_mobile``. If everything works well, you should be able to see the following ops being printed out from the console


.. code:: shell

['aten::adaptive_avg_pool2d',
'aten::add.Tensor',
'aten::addmm',
'aten::reshape',
'aten::size.int',
'metal::copy_to_host',
'metal_prepack::conv2d_run']

Those are all the ops we need to run the mobilenetv2 model on iOS GPU. Cool! Now that you have the ``mobilenetv2_metal.pt`` saved on your disk, let's move on to the iOS part.


Use C++ APIs
---------------------

In this section, we'll be using the `HelloWorld example <https://github.com/pytorch/ios-demo-app>`_ to demonstrate how to use the C++ APIs. The first thing we need to do is to build a custom LibTorch from Source. Make sure you have deleted the **build** folder from the previous step in PyTorch root directory. Then run the command below

.. code:: shell

IOS_ARCH=arm64 USE_PYTORCH_METAL=1 ./scripts/build_ios.sh

Note ``IOS_ARCH`` tells the script to build a arm64 version of Libtorch. This is because in PyTorch, Metal is only available for the iOS devices that support the Apple A9 chip or above. Once the build finished, follow the `Build PyTorch iOS libraries from source <https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source>`_ section from the iOS tutorial to setup the XCode settings properly. Don't forget to copy the `./mobilenetv2_metal.pt` to your XCode project.

Next we need to make some changes in ``TorchModule.mm``

.. code:: objective-c

- (NSArray<NSNumber*>*)predictImage:(void*)imageBuffer {
torch::jit::GraphOptimizerEnabledGuard opguard(false);
at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, 224, 224}, at::kFloat).metal();
auto outputTensor = _impl.forward({tensor}).toTensor().cpu();
...
return nil;
}

As you can see, we simply just call ``.metal()`` to move our input tensor from CPU to GPU, and then call ``.cpu()`` to move the result back. Internally, ``.metal()`` will copy the input data from the CPU buffer to a GPU buffer with a GPU compatible memory format. When `.cpu()` is invoked, the GPU command buffer will be flushed and synced. After `forward` finished, the final result will then be copied back from the GPU buffer back to a CPU buffer.

The last step we have to do is to add the `Accelerate.framework` and the `MetalShaderPerformance.framework` to your xcode project.

If everything works fine, you should be able to see the inference results on your phone. The result below was captured from an iPhone11 device

.. code:: shell

- timber wolf, grey wolf, gray wolf, Canis lupus
- malamute, malemute, Alaskan malamute
- Eskimo dog, husky

You may notice that the results are slighly different from the `results <https://pytorch.org/mobile/ios/#install-libtorch-via-cocoapods>`_ we got from the CPU model as shown in the iOS tutorial. This is because by default Metal uses fp16 rather than fp32 to compute. The precision loss is expected.


Conclusion
----------

In this tutorial, we demonstrated how to convert a mobilenetv2 model to a GPU compatible model. We walked through a HelloWorld example to show how to use the C++ APIs to run models on iOS GPU. Please be aware of that GPU feature is still under development, new operators will continue to be added. APIs are subject to change in the future versions.

Thanks for reading! As always, we welcome any feedback, so please create an issue `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.

Learn More
----------

- The `Mobilenetv2 <https://pytorch.org/hub/pytorch_vision_mobilenet_v2/>`_ from Torchvision
- To learn more about how to use ``optimize_for_mobile``, please refer to the `Mobile Perf Recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_
157 changes: 157 additions & 0 deletions recipes_source/fuse.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
Fuse Modules Recipe
=====================================

This recipe demonstrates how to fuse a list of PyTorch modules into a single module and how to do the performance test to compare the fused model with its non-fused version.

Introduction
------------

Before quantization is applied to a model to reduce its size and memory footprint (see `Quantization Recipe <quantization.html>`_ for details on quantization), the list of modules in the model may be fused first into a single module. Fusion is optional, but it may save on memory access, make the model run faster, and improve its accuracy.


Pre-requisites
--------------

PyTorch 1.6.0 or 1.7.0

Steps
--------------

Follow the steps below to fuse an example model, quantize it, script it, optimize it for mobile, save it and test it with the Android benchmark tool.

1. Define the Example Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Use the same example model defined in the `PyTorch Mobile Performance Recipes <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_:

::

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AnnotatedConvBnReLUModel(torch.nn.Module):
def __init__(self):
super(AnnotatedConvBnReLUModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
self.relu = torch.nn.ReLU(inplace=True)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()

def forward(self, x):
x.contiguous(memory_format=torch.channels_last)
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.dequant(x)
return x


2. Generate Two Models with and without `fuse_modules`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Add the following code below the model definition above and run the script:

::

model = AnnotatedConvBnReLUModel()

def prepare_save(model, fused):
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
torchscript_model = torch.jit.script(model)
torchscript_model_optimized = optimize_for_mobile(torchscript_model)
torch.jit.save(torchscript_model_optimized, "model.pt" if not fused else "model_fused.pt")

prepare_save(model, False)
model_fused = torch.quantization.fuse_modules(model, [['bn', 'relu']], inplace=False)
prepare_save(model_fused, True)

print(model)
print(model_fused)




The graphs of the original model and its fused version will be printed as follows:

::

AnnotatedConvBnReLUModel(
(conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(quant): QuantStub()
(dequant): DeQuantStub()
)

AnnotatedConvBnReLUModel(
(conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn): BNReLU2d(
(0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
)
(relu): Identity()
(quant): QuantStub()
(dequant): DeQuantStub()
)

In the second fused model output, the first item `bn` in the list is replaced with the fused module, and the rest of the modules (`relu` in this example) is replaced with identity. In addition, the non-fused and fused versions of the model `model.pt` and `model_fused.pt` are generated.

3. Build the Android benchmark Tool
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Get the PyTorch source and build the Android benchmark tool as follows:

::

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git submodule update --init --recursive
BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON


This will generate the Android benchmark binary `speed_benchmark_torch` in the `build_android/bin` folder.

4. Test Compare the Fused and Non-Fused Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Connect your Android device, then copy `speed_benchmark_torch` and the model files and run the benchmark tool on them:

::

adb push build_android/bin/speed_benchmark_torch /data/local/tmp
adb push model.pt /data/local/tmp
adb push model_fused.pt /data/local/tmp
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt" --input_dims="1,3,224,224" --input_type="float"
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model_fused.pt" --input_dims="1,3,224,224" --input_type="float"


The results from the last two commands should be like:

::

Main run finished. Microseconds per iter: 6189.07. Iters per second: 161.575

and

::

Main run finished. Microseconds per iter: 6216.65. Iters per second: 160.858

For this example model, there is no much performance difference between the fused and non-fused models. But the similar steps can be used to fuse and prepare a real deep model and test to see the performance improvement. Keep in mind that currently `torch.quantization.fuse_modules` only fuses the following sequence of modules:

* conv, bn
* conv, bn, relu
* conv, relu
* linear, relu
* bn, relu

If any other sequence list is provided to the `fuse_modules` call, it will simply be ignored.

Learn More
---------------

See `here <https://pytorch.org/docs/stable/quantization.html#preparing-model-for-quantization>`_ for the official documentation of `torch.quantization.fuse_modules`.
85 changes: 85 additions & 0 deletions recipes_source/model_preparation_android.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
Model Preparation for Android Recipe
=====================================

This recipe demonstrates how to prepare a PyTorch MobileNet v2 image classification model for Android apps, and how to set up Android projects to use the mobile-ready model file.

Introduction
-----------------

After a PyTorch model is trained or a pre-trained model is made available, it is normally not ready to be used in mobile apps yet. It needs to be quantized (see the `Quantization Recipe <quantization.html>`_), converted to TorchScript so Android apps can load it, and optimized for mobile apps. Furthermore, Android apps need to be set up correctly to enable the use of PyTorch Mobile libraries, before they can load and use the model for inference.

Pre-requisites
-----------------

PyTorch 1.6.0 or 1.7.0

torchvision 0.6.0 or 0.7.0

Android Studio 3.5.1 or above with NDK installed

Steps
-----------------

1. Get Pretrained and Quantized MobileNet v2 Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To get the MobileNet v2 quantized model, simply do:

::

import torchvision

model_quantized = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)

2. Script and Optimize the Model for Mobile Apps
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Use either the `script` or `trace` method to convert the quantized model to the TorchScript format:

::

import torch

dummy_input = torch.rand(1, 3, 224, 224)
torchscript_model = torch.jit.trace(model_quantized, dummy_input)

or

::

torchscript_model = torch.jit.script(model_quantized)


.. warning::
The `trace` method only scripts the code path executed during the trace, so it will not work properly for models that include decision branches. See the `Script and Optimize for Mobile Recipe <script_optimized.html>`_ for more details.

Then optimize the TorchScript formatted model for mobile and save it:

::

from torch.utils.mobile_optimizer import optimize_for_mobile
torchscript_model_optimized = optimize_for_mobile(torchscript_model)
torch.jit.save(torchscript_model_optimized, "mobilenetv2_quantized.pt")

With the total 7 or 8 (depending on if the `script` or `trace` method is called to get the TorchScript format of the model) lines of code in the two steps above, we have a model ready to be added to mobile apps.

3. Add the Model and PyTorch Library on Android
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* In your current or a new Android Studio project, open the build.gradle file, and add the following two lines (the second one is required only if you plan to use a TorchVision model):

::

implementation 'org.pytorch:pytorch_android:1.6.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'

* Drag and drop the model file `mobilenetv2_quantized.pt` to your project's assets folder.

That's it! Now you can build your Android app with the PyTorch library and the model ready to use. To actually write code to use the model, refer to the PyTorch Mobile `Android Quickstart with a HelloWorld Example <https://pytorch.org/mobile/android/#quickstart-with-a-helloworld-example>`_ and `Android Hackathon Example <https://github.com/pytorch/workshops/tree/master/PTMobileWalkthruAndroid>`_.

Learn More
-----------------

1. `PyTorch Mobile site <https://pytorch.org/mobile>`_

2. `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
Loading