
Commit 41f58c1

zou3519 and brianjo authored
Add recipe for prototype vmap (#1209)
Test Plan:
- Run `vmap_recipe.py` locally

Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent a299e79 commit 41f58c1

File tree

1 file changed: +121 -0

prototype_source/vmap_recipe.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
"""
torch.vmap
==========
This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations.
torch.vmap is a prototype feature and cannot handle a number of use cases;
however, we would like to gather use cases for it to inform the design. If you
are considering using torch.vmap or think it would be really cool for something,
please contact us at https://github.com/pytorch/pytorch/issues/42368.

So, what is vmap?
-----------------
vmap is a higher-order function. It accepts a function `func` and returns a new
function that maps `func` over some dimension of the inputs. It is highly
inspired by JAX's vmap.

Semantically, vmap pushes the "map" into PyTorch operations called by `func`,
effectively vectorizing those operations.
"""
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap

####################################################################
# The first use case for vmap is making it easier to handle
# batch dimensions in your code. One can write a function `func`
# that runs on examples and then lift it to a function that can
# take batches of examples with `vmap(func)`. `func`, however,
# is subject to many restrictions:
# - it must be functional (one cannot mutate a Python data structure
#   inside of it), with the exception of in-place PyTorch operations.
# - batches of examples must be provided as Tensors. This means that
#   vmap doesn't handle variable-length sequences out of the box.
#
# One example of using `vmap` is to compute batched dot products. PyTorch
# doesn't provide a batched `torch.dot` API; instead of unsuccessfully
# rummaging through docs, use `vmap` to construct a new function:

torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

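####################################################################
# Quick sanity check (an addition to the recipe, not part of the original):
# the vmapped function should agree with the obvious Python loop over the
# batch dimension.
expected_dots = torch.stack([torch.dot(a, b)
                             for a, b in zip(x.unbind(), y.unbind())])
assert torch.allclose(batched_dot(x, y), expected_dots)
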
####################################################################
# `vmap` can be helpful in hiding batch dimensions, leading to a simpler
# model authoring experience.
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that the model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigans), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument;
# please see the documentation for more details).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)

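####################################################################
# A short illustrative sketch (an addition, not part of the original recipe):
# `in_dims` lets you state explicitly which inputs are batched. Here the
# examples are mapped over dimension 0 while the shared `weights` vector is
# passed through unmapped (None), assuming the prototype's `in_dims` accepts
# a per-argument tuple as described in its documentation.
def model_explicit(feature_vec, weight_vec):
    return feature_vec.dot(weight_vec).relu()

result_explicit = torch.vmap(model_explicit, in_dims=(0, None))(examples, weights)
assert torch.allclose(result_explicit, expected)
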
####################################################################
# `vmap` can also help vectorize computations that were previously difficult
# or impossible to batch. This brings us to our second use case: batched
# gradient computation.
# - https://github.com/pytorch/pytorch/issues/8304
# - https://github.com/pytorch/pytorch/issues/23475
#
# The PyTorch autograd engine computes vjps (vector-Jacobian products).
# Using vmap, we can compute (batched vector)-Jacobian products.
#
# One example of this is computing a full Jacobian matrix (this can also be
# applied to computing a full Hessian matrix).
# Computing a full Jacobian matrix for some function f: R^N -> R^N usually
# requires N calls to `autograd.grad`, one per Jacobian row.

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach: one vjp per basis vector, i.e., one per Jacobian row.
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

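####################################################################
# A hedged sketch of the Hessian case mentioned above (an addition, not part
# of the original recipe): for a scalar function g, each Hessian row is a vjp
# of the gradient, so the rows can be vmapped just like the Jacobian rows,
# assuming the prototype handles this double-backward pattern the same way.
def g(x):
    return (x ** 3).sum()

x2 = torch.randn(N, requires_grad=True)
grad_x2, = torch.autograd.grad(g(x2), x2, create_graph=True)

def get_hessian_row(v):
    return torch.autograd.grad(grad_x2, x2, v, retain_graph=True)[0]

hessian = vmap(get_hessian_row)(basis_vectors)
# For g(x) = sum(x ** 3), the Hessian is diag(6 * x).
assert torch.allclose(hessian, torch.diag(6 * x2))
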
####################################################################
# The third main use case for vmap is computing per-sample-gradients.
# This is something that the vmap prototype cannot handle performantly
# right now. We're not sure what the API for computing per-sample-gradients
# should be, but if you have ideas, please comment in
# https://github.com/pytorch/pytorch/issues/7786.

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    # Gradient of the per-sample output with respect to the shared `weights`
    # tensor defined earlier; `vjp` returns (output, grad), so take index 1.
    return torch.autograd.functional.vjp(
        lambda weight: model(sample, weight), weights)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
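
####################################################################
# A reference point (an addition, not part of the original recipe): today,
# per-sample gradients can be computed with an explicit Python loop, which is
# exactly the computation a vmap-based API would aim to vectorize.
batch_of_samples = torch.randn(64, 5)
per_sample_grads = torch.stack([grad_sample(sample)
                                for sample in batch_of_samples.unbind()])
assert per_sample_grads.shape == (64, 5)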
