Skip to content

[android] Update HelloWorld android tutorial #710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions _mobile/android.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This application runs TorchScript serialized TorchVision pretrained resnet18 mod

#### 1. Model Preparation

Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model ([Resnet18](https://pytorch.org/hub/pytorch_vision_resnet/)), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html).
Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model ([MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/)).
To install it, run the command below:
```
pip install torchvision
Expand All @@ -27,12 +27,14 @@ To serialize the model you can use python [script](https://github.com/pytorch/an
```
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.resnet18(pretrained=True)
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/model.pt")
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized.save("app/src/main/assets/model.pt")
```
If everything works well, we should have our model - `model.pt` generated in the assets folder of android application.
That will be packaged inside android application as `asset` and can be used on the device.
Expand Down Expand Up @@ -62,8 +64,8 @@ repositories {
}

dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
implementation 'org.pytorch:pytorch_android:1.8.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0'
}
```
Where `org.pytorch:pytorch_android` is the main dependency with PyTorch Android API, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
Expand Down