
Commit 95e40d9

Merge branch 'master' into transformer_ts
2 parents e471b0d + 84bf3e3 commit 95e40d9

11 files changed (+1237, -1 lines)

prototype_source/README.txt

Lines changed: 4 additions & 0 deletions
@@ -19,3 +19,7 @@ Prototype Tutorials
 5. torchscript_freezing.py
    Model Freezing in TorchScript
    https://github.com/pytorch/tutorials/blob/master/prototype_source/torchscript_freezing.py
+
+6. vulkan_workflow.rst
+   Vulkan Backend User Workflow
+   https://pytorch.org/tutorials/intermediate/vulkan_workflow.html

prototype_source/ios_gpu_workflow.rst

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
(Prototype) Use iOS GPU in PyTorch
==================================

**Author**: `Tao Xu <https://github.com/xta0>`_

Introduction
------------

This tutorial introduces the steps to run your models on the iOS GPU. We'll use the MobileNetV2 model as an example. Since the mobile GPU features are currently in the prototype stage, you'll need to build a custom PyTorch binary from source. For the time being, only a limited number of operators are supported, and certain client-side APIs are subject to change in future versions.

Model Preparation
-----------------

Since GPUs consume weights in a different order, the first step is to convert our TorchScript model to a GPU-compatible model. This step is also known as "prepacking". To do that, we'll build a custom PyTorch binary from source that includes the Metal backend. Go ahead and check out the PyTorch source code from GitHub, then run the command below

.. code:: shell

    cd PYTORCH_ROOT
    USE_PYTORCH_METAL=ON python setup.py install --cmake

The command above will build a custom PyTorch binary from master. The ``install`` argument tells ``setup.py`` to override the existing PyTorch installation on your desktop. Once the build finishes, open another terminal and check the PyTorch version to see whether the installation was successful. At the time of writing this recipe, the version is ``1.8.0a0+41237a4``. You might see different numbers depending on when you check out the code from master, but the version should be greater than 1.7.0.

.. code:: python

    import torch
    torch.__version__  # 1.8.0a0+41237a4

The next step is to convert the MobileNetV2 TorchScript model to a Metal-compatible model. We'll leverage the ``optimize_for_mobile`` API from the ``torch.utils.mobile_optimizer`` module, as shown below

.. code:: python

    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile

    model = torchvision.models.mobilenet_v2(pretrained=True)
    scripted_model = torch.jit.script(model)
    optimized_model = optimize_for_mobile(scripted_model, backend='metal')
    print(torch.jit.export_opnames(optimized_model))
    torch.jit.save(optimized_model, './mobilenetv2_metal.pt')

Note that ``torch.jit.export_opnames(optimized_model)`` dumps all the operators in ``optimized_model``. If everything works well, you should see the following ops printed out in the console

.. code:: shell

    ['aten::adaptive_avg_pool2d',
    'aten::add.Tensor',
    'aten::addmm',
    'aten::reshape',
    'aten::size.int',
    'metal::copy_to_host',
    'metal_prepack::conv2d_run']

Those are all the ops we need to run the MobileNetV2 model on the iOS GPU. Cool! Now that you have ``mobilenetv2_metal.pt`` saved on disk, let's move on to the iOS part.

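Before heading to Xcode, you can optionally sanity-check the exported model programmatically. The snippet below is a minimal sketch, reusing ``optimized_model`` from the previous snippet and the op list printed above; it asserts that no operator outside that list was exported:

.. code:: python

    # Expected ops, copied from the console output above.
    expected_ops = {
        'aten::adaptive_avg_pool2d', 'aten::add.Tensor', 'aten::addmm',
        'aten::reshape', 'aten::size.int', 'metal::copy_to_host',
        'metal_prepack::conv2d_run',
    }
    actual_ops = set(torch.jit.export_opnames(optimized_model))
    # Any op outside the expected set would be unsupported on the Metal backend.
    assert actual_ops <= expected_ops, actual_ops - expected_ops
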
Use C++ APIs
------------

In this section, we'll use the `HelloWorld example <https://github.com/pytorch/ios-demo-app>`_ to demonstrate how to use the C++ APIs. The first thing we need to do is build a custom LibTorch from source. Make sure you have deleted the **build** folder from the previous step in the PyTorch root directory, then run the command below

.. code:: shell

    IOS_ARCH=arm64 USE_PYTORCH_METAL=1 ./scripts/build_ios.sh

Note that ``IOS_ARCH`` tells the script to build an arm64 version of LibTorch. This is because in PyTorch, Metal is only available on iOS devices with an Apple A9 chip or above. Once the build finishes, follow the `Build PyTorch iOS libraries from source <https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source>`_ section of the iOS tutorial to set up the Xcode settings properly. Don't forget to copy ``./mobilenetv2_metal.pt`` to your Xcode project.

Next, we need to make some changes in ``TorchModule.mm``

.. code:: objective-c

    - (NSArray<NSNumber*>*)predictImage:(void*)imageBuffer {
        torch::jit::GraphOptimizerEnabledGuard opguard(false);
        at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, 224, 224}, at::kFloat).metal();
        auto outputTensor = _impl.forward({tensor}).toTensor().cpu();
        ...
        return nil;
    }

As you can see, we simply call ``.metal()`` to move the input tensor from CPU to GPU, and then call ``.cpu()`` to move the result back. Internally, ``.metal()`` copies the input data from the CPU buffer to a GPU buffer with a GPU-compatible memory format. When ``.cpu()`` is invoked, the GPU command buffer is flushed and synced. After ``forward`` finishes, the final result is copied from the GPU buffer back to a CPU buffer.

The last step is to add ``Accelerate.framework`` and ``MetalPerformanceShaders.framework`` to your Xcode project.

If everything works fine, you should be able to see the inference results on your phone. The results below were captured on an iPhone 11 device

.. code:: shell

    - timber wolf, grey wolf, gray wolf, Canis lupus
    - malamute, malemute, Alaskan malamute
    - Eskimo dog, husky

You may notice that the results are slightly different from the `results <https://pytorch.org/mobile/ios/#install-libtorch-via-cocoapods>`_ we got from the CPU model in the iOS tutorial. This is because, by default, Metal computes in fp16 rather than fp32, so some precision loss is expected.

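To get a feel for the size of that fp16 rounding, here is a small, self-contained sketch (synthetic values, unrelated to the Metal pipeline) that round-trips fp32 values through fp16:

.. code:: python

    import torch

    # Round-tripping fp32 values through fp16 introduces a small but
    # nonzero error -- enough to shift scores that are very close.
    scores = torch.randn(1000)
    roundtrip = scores.half().float()
    print((scores - roundtrip).abs().max())  # small but nonzero
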
Conclusion
----------

In this tutorial, we demonstrated how to convert a MobileNetV2 model to a GPU-compatible model, and walked through a HelloWorld example showing how to use the C++ APIs to run models on the iOS GPU. Please be aware that the GPU feature is still under development; new operators will continue to be added, and APIs are subject to change in future versions.

Thanks for reading! As always, we welcome any feedback, so please create an issue `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.

Learn More
----------

- The `MobileNetV2 <https://pytorch.org/hub/pytorch_vision_mobilenet_v2/>`_ model from Torchvision
- To learn more about how to use ``optimize_for_mobile``, please refer to the `Mobile Perf Recipe <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_

prototype_source/vmap_recipe.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
"""
torch.vmap
==========
This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations.
torch.vmap is a prototype feature and cannot handle a number of use cases;
however, we would like to gather use cases for it to inform the design. If you
are considering using torch.vmap or think it would be really cool for something,
please contact us at https://github.com/pytorch/pytorch/issues/42368.

So, what is vmap?
-----------------
vmap is a higher-order function. It accepts a function `func` and returns a new
function that maps `func` over some dimension of the inputs. It is highly
inspired by JAX's vmap.

Semantically, vmap pushes the "map" into PyTorch operations called by `func`,
effectively vectorizing those operations.
"""
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap

####################################################################
# The first use case for vmap is making it easier to handle
# batch dimensions in your code. One can write a function `func`
# that runs on examples and then lift it to a function that can
# take batches of examples with `vmap(func)`. `func` however
# is subject to many restrictions:
# - it must be functional (one cannot mutate a Python data structure
#   inside of it), with the exception of in-place PyTorch operations.
# - batches of examples must be provided as Tensors. This means that
#   vmap doesn't handle variable-length sequences out of the box.
#
# One example of using `vmap` is to compute batched dot products. PyTorch
# doesn't provide a batched `torch.dot` API; instead of unsuccessfully
# rummaging through docs, use `vmap` to construct a new function:

torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

####################################################################
# `vmap` can be helpful in hiding batch dimensions, leading to a simpler
# model authoring experience.
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that the model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigans), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified with the `in_dims` argument
# (please see the documentation for more details; a short sketch of
# `in_dims` follows the example below).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)

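####################################################################
# A minimal sketch of the `in_dims` argument mentioned above: map over
# dim 0 of `examples` while treating `weights` as unbatched, i.e. shared
# across all examples (`None` means "don't map over this input").
shared = torch.vmap(torch.dot, in_dims=(0, None))(examples, weights)
assert torch.allclose(
    shared, torch.stack([e.dot(weights) for e in examples.unbind()]))
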
####################################################################
# `vmap` can also help vectorize computations that were previously difficult
# or impossible to batch. This brings us to our second use case: batched
# gradient computation.
# - https://github.com/pytorch/pytorch/issues/8304
# - https://github.com/pytorch/pytorch/issues/23475
#
# The PyTorch autograd engine computes vjps (vector-Jacobian products).
# Using vmap, we can compute (batched vector)-Jacobian products.
#
# One example of this is computing a full Jacobian matrix (this can also be
# applied to computing a full Hessian matrix; a sketch follows the Jacobian
# example below).
# Computing a full Jacobian matrix for some function f: R^N -> R^N usually
# requires N calls to `autograd.grad`, one per Jacobian row.

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

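####################################################################
# A sketch of the Hessian remark above (assuming the prototype handles
# vmap over double-backward, which may not hold in every case): for a
# scalar function g, the Hessian is the Jacobian of the gradient, so we
# can reuse the vmap-of-vjp trick with create_graph=True.
def g(x):
    return (x ** 3).sum()

x2 = torch.randn(N, requires_grad=True)
# Keep the graph of the gradient so we can differentiate through it.
grad_g, = torch.autograd.grad(g(x2), x2, create_graph=True)

def get_hvp(v):
    # Hessian-vector product: differentiate the gradient against v.
    return torch.autograd.grad(grad_g, x2, v, retain_graph=True)[0]

hessian = vmap(get_hvp)(basis_vectors)
# For g(x) = sum(x^3), the Hessian is diag(6x).
assert torch.allclose(hessian, torch.diag(6 * x2.detach()))
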
####################################################################
# The third main use case for vmap is computing per-sample-gradients.
# This is something that the vmap prototype cannot handle performantly
# right now. We're not sure what the API for computing per-sample-gradients
# should be, but if you have ideas, please comment at
# https://github.com/pytorch/pytorch/issues/7786.

# A toy feature weight, shared across samples, for the sketch below.
weight = torch.randn(5, requires_grad=True)

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    return torch.autograd.functional.vjp(
        lambda weight: model(sample, weight), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
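
####################################################################
# For reference, a minimal sketch of what that call would be expected to
# compute, written today as an explicit (slow) loop: one vjp per sample.
batch_of_samples = torch.randn(64, 5)
per_sample_grads = torch.stack(
    [grad_sample(s) for s in batch_of_samples.unbind()])
assert per_sample_grads.shape == (64, 5)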
