diff --git a/_mobile/android.md b/_mobile/android.md index c87f352c70c4..4c53dac3d0f7 100644 --- a/_mobile/android.md +++ b/_mobile/android.md @@ -17,7 +17,7 @@ This application runs TorchScript serialized TorchVision pretrained resnet18 mod #### 1. Model Preparation -Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model ([Resnet18](https://pytorch.org/hub/pytorch_vision_resnet/)), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html). +Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model ([MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/)). To install it, run the command below: ``` pip install torchvision @@ -27,12 +27,14 @@ To serialize the model you can use python [script](https://github.com/pytorch/an ``` import torch import torchvision +from torch.utils.mobile_optimizer import optimize_for_mobile -model = torchvision.models.resnet18(pretrained=True) +model = torchvision.models.mobilenet_v2(pretrained=True) model.eval() example = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, example) -traced_script_module.save("app/src/main/assets/model.pt") +traced_script_module_optimized = optimize_for_mobile(traced_script_module) +traced_script_module_optimized.save("app/src/main/assets/model.pt") ``` If everything works well, we should have our model - `model.pt` generated in the assets folder of android application. That will be packaged inside android application as `asset` and can be used on the device. @@ -62,8 +64,8 @@ repositories { } dependencies { - implementation 'org.pytorch:pytorch_android:1.4.0' - implementation 'org.pytorch:pytorch_android_torchvision:1.4.0' + implementation 'org.pytorch:pytorch_android:1.8.0' + implementation 'org.pytorch:pytorch_android_torchvision:1.8.0' } ``` Where `org.pytorch:pytorch_android` is the main dependency with PyTorch Android API, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).