"""
torch.vmap
==========
This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations.
torch.vmap is a prototype feature and cannot handle a number of use cases;
however, we would like to gather use cases for it to inform the design. If you
are considering using torch.vmap or think it would be really cool for something,
please contact us at https://github.com/pytorch/pytorch/issues/42368.

So, what is vmap?
-----------------
vmap is a higher-order function. It accepts a function `func` and returns a new
function that maps `func` over some dimension of the inputs. It is heavily
inspired by JAX's vmap.

Semantically, vmap pushes the "map" into the PyTorch operations called by `func`,
effectively vectorizing those operations.
"""
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap

####################################################################
# The first use case for vmap is making it easier to handle
# batch dimensions in your code. One can write a function `func`
# that runs on examples and then lift it to a function that can
# take batches of examples with `vmap(func)`. `func`, however,
# is subject to many restrictions:
# - it must be functional (one cannot mutate a Python data structure
#   inside of it), with the exception of in-place PyTorch operations.
# - batches of examples must be provided as Tensors. This means that
#   vmap doesn't handle variable-length sequences out of the box.
#
# One example of using `vmap` is to compute batched dot products. PyTorch
# doesn't provide a batched `torch.dot` API; instead of unsuccessfully
# rummaging through docs, use `vmap` to construct a new function:

torch.dot  # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

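####################################################################
# As a quick sanity check, `vmap(torch.dot)` should behave like mapping
# `torch.dot` over dim 0 of both inputs. A small comparison against an
# explicit Python loop (the `expected_dots` name is just for illustration):
expected_dots = torch.stack([torch.dot(a, b) for a, b in zip(x, y)])
assert torch.allclose(batched_dot(x, y), expected_dots)
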
####################################################################
# `vmap` can be helpful in hiding batch dimensions, leading to a simpler
# model authoring experience.
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that `model` doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigans), we can simply call `vmap`. `vmap` batches over ALL
# inputs unless otherwise specified with the in_dims argument (please see
# the documentation for more details; a small sketch follows this example).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)

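####################################################################
# A minimal sketch of the in_dims behavior mentioned above, assuming the
# prototype follows JAX-style semantics where a `None` entry in `in_dims`
# means "don't map over this input". The two-argument `linear` function
# here is just a hypothetical illustration.
def linear(weight, feature_vec):
    return feature_vec.dot(weight).relu()

# Map over dim 0 of `examples` only; `weights` is shared across the batch.
result_shared = vmap(linear, in_dims=(None, 0))(weights, examples)
assert torch.allclose(result_shared, expected)
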
####################################################################
# `vmap` can also help vectorize computations that were previously difficult
# or impossible to batch. This brings us to our second use case: batched
# gradient computation.
# - https://github.com/pytorch/pytorch/issues/8304
# - https://github.com/pytorch/pytorch/issues/23475
#
# The PyTorch autograd engine computes vjps (vector-Jacobian products).
# Using vmap, we can compute (batched vector)-Jacobian products.
#
# One example of this is computing a full Jacobian matrix (the same idea
# applies to computing a full Hessian matrix; see the sketch after this
# example).
# Computing a full Jacobian matrix for some function f: R^N -> R^N usually
# requires N calls to `autograd.grad`, one per Jacobian row.

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

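####################################################################
# The Hessian case mentioned above works the same way; here is a minimal
# sketch under the same prototype caveats. For the scalar function
# g(x) = f(x).sum(), we vmap one Hessian-vector product per basis vector.
def g(x):
    return f(x).sum()

# Gradient of g with a graph attached, so we can differentiate through it.
grad_x = torch.autograd.grad(g(x), x, create_graph=True)[0]

def get_hvp(v):
    # A vjp against the gradient is a Hessian-vector product.
    return torch.autograd.grad(grad_x, x, v, retain_graph=True)[0]

hessian_vmap = vmap(get_hvp)(basis_vectors)
# For f(x) = x ** 2, the Hessian of g is 2 * I.
assert torch.allclose(hessian_vmap, 2 * torch.eye(N))
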
####################################################################
# The third main use case for vmap is computing per-sample-gradients.
# This is something that the vmap prototype cannot handle performantly
# right now. We're not sure what the API for computing per-sample-gradients
# should be, but if you have ideas, please comment on
# https://github.com/pytorch/pytorch/issues/7786.

# A hypothetical model parameter, defined so the example is self-contained.
weight = torch.randn(5, requires_grad=True)

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    # Gradient of the model's output with respect to `weight` for one sample.
    return torch.autograd.functional.vjp(lambda weight: model(sample, weight), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
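
####################################################################
# For reference, per-sample-gradients can be computed today with a plain
# Python loop over the batch (correct but not vectorized); this is the
# computation the hypothetical vmap call above would ideally replace.
batch_of_samples = torch.randn(64, 5)
per_sample_grads = torch.stack([grad_sample(s) for s in batch_of_samples.unbind()])
# For this model, the gradient w.r.t. `weight` for each sample is the sample itself.
assert torch.allclose(per_sample_grads, batch_of_samples)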