
Commit 9d9be8f

zou3519 and Svetlana Karslioglu authored
Add torch.func tutorials for PyTorch 2.0 (#2171)
As the final step of integrating functorch into PyTorch, we have moved the functorch APIs from under functorch.* to torch.func.* and made some adjustments to them. This PR moves the relevant functorch tutorials from the functorch docs (https://pytorch.org/functorch/stable/) to pytorch/tutorials. We moved four tutorials:

- Jacobians, Hessians, hvp, vhp, and more
- Model ensembling
- per-sample-gradients
- Neural Tangent Kernels

We also rewrote the tutorials to use the torch.func.* APIs instead of the functorch APIs, and excised mentions of functorch where appropriate.

Test Plan:
- view preview (is that possible for tutorials?)

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 1fa0cd0 commit 9d9be8f

File tree

5 files changed: +1013 -0 lines changed


index.rst

Lines changed: 32 additions & 0 deletions
@@ -417,6 +417,34 @@ What's new in PyTorch tutorials?
    :link: intermediate/forward_ad_usage.html
    :tags: Frontend-APIs
 
+.. customcarditem::
+   :header: Jacobians, Hessians, hvp, vhp, and more
+   :card_description: Learn how to compute advanced autodiff quantities using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/jacobians_hessians.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Model Ensembling
+   :card_description: Learn how to ensemble models using torch.vmap
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/ensembling.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Per-Sample-Gradients
+   :card_description: Learn how to compute per-sample-gradients using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/per_sample_grads.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Neural Tangent Kernels
+   :card_description: Learn how to compute neural tangent kernels using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/neural_tangent_kernels.html
+   :tags: Frontend-APIs
+
 .. Model Optimization
 
 .. customcarditem::
@@ -877,6 +905,10 @@ Additional Resources
 
    intermediate/memory_format_tutorial
    intermediate/forward_ad_usage
+   intermediate/jacobians_hessians
+   intermediate/ensembling
+   intermediate/per_sample_grads
+   intermediate/neural_tangent_kernels
    advanced/cpp_frontend
    advanced/torch-script-parallelism
    advanced/cpp_autograd

intermediate_source/ensembling.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""
Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps functions across
dimensions of the input tensors. One of its use cases is eliminating
for-loops and speeding them up through vectorization.

Let's demonstrate how to do this using an ensemble of simple MLPs.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
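
######################################################################
# Before building the ensemble, here is a minimal sketch of what ``vmap``
# does on its own (the function and shapes below are made up purely for
# illustration): it takes a function written for single examples and returns
# a new function that works on a batch of examples.

xs, ys = torch.randn(3, 5), torch.randn(3, 5)
batched_dot = torch.vmap(torch.dot)  # maps torch.dot over dim 0 of both inputs
print(batched_dot(xs, ys).shape)  # torch.Size([3])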

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

######################################################################
# Let’s generate a batch of dummy data and pretend that we’re working with
# an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a
# minibatch of size 64. Furthermore, let’s say we want to combine the predictions
# from 10 different models.

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

######################################################################
# We have a couple of options for generating predictions. Maybe we want to
# give each model a different randomized minibatch of data. Alternatively,
# maybe we want to run the same minibatch of data through each model (e.g.
# if we were testing the effect of different model initializations).

######################################################################
# Option 1: different minibatch for each model

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

######################################################################
# Option 2: Same minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

######################################################################
# Using vmap to vectorize the ensemble
# ------------------------------------
#
# Let's use vmap to speed up the for-loop. We must first prepare the models
# for use with vmap.
#
# First, let’s combine the states of the models together by stacking each
# parameter. For example, ``model[i].fc1.weight`` has shape ``[128, 784]``; we are
# going to stack the ``.fc1.weight`` of each of the 10 models to produce a big
# weight of shape ``[10, 128, 784]``.
#
# PyTorch offers the ``torch.func.stack_module_state`` convenience function to do
# this.
from torch.func import stack_module_state

params, buffers = stack_module_state(models)
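
######################################################################
# For intuition, here is a rough, simplified sketch of what the stacking
# amounts to: each parameter name maps to the corresponding parameters of all
# 10 models stacked along a new leading dimension. (Illustrative only; the
# real ``stack_module_state`` also handles buffers and gradient settings.)

manual_params = {
    name: torch.stack([dict(m.named_parameters())[name] for m in models])
    for name, _ in models[0].named_parameters()
}
print(manual_params['fc1.weight'].shape)  # torch.Size([10, 128, 784])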

######################################################################
# Next, we need to define a function to vmap over. The function should,
# given parameters and buffers and inputs, run the model using those
# parameters, buffers, and inputs. We'll use ``torch.func.functional_call``
# to help out:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))
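
######################################################################
# As a quick sanity check (a small aside, not required for the ensemble),
# calling ``fmodel`` with a single model's slice of the stacked state should
# reproduce that model's output on a minibatch:

single_params = {k: v[0] for k, v in params.items()}
single_buffers = {k: v[0] for k, v in buffers.items()}
assert torch.allclose(fmodel(single_params, single_buffers, data[0]),
                      models[0](data[0]), atol=1e-5)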

######################################################################
# Option 1: get predictions using a different minibatch for each model.
#
# By default, ``vmap`` maps a function across the first dimension of all of
# the inputs to the passed-in function. After using ``stack_module_state``,
# each of the params and buffers has an additional dimension of size
# 'num_models' at the front, and minibatches has a dimension of size
# 'num_models'.

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the vmap predictions match the predictions from the for-loop
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

######################################################################
# Option 2: get predictions using the same minibatch of data.
#
# ``vmap`` has an ``in_dims`` argument that specifies which dimensions to map over.
# By using ``None``, we tell vmap we want the same minibatch to apply to all of
# the 10 models.

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
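
######################################################################
# A small aside on ``in_dims`` (the shapes here are illustrative, not part of
# the ensemble): an integer picks which dimension of that argument gets
# mapped over, while ``None`` reuses the same argument for every call.

xs = torch.randn(5, 3, device=device)
w = torch.randn(3, device=device)
print(vmap(torch.dot, in_dims=(0, None))(xs, w).shape)  # torch.Size([5])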

######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are pure functions:
# functions whose outputs are determined only by the inputs and that have no
# side effects (e.g. mutation). vmap is unable to handle mutation of
# arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.
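
######################################################################
# For example (an illustrative sketch), the following function mutates a
# tensor in-place, but it is still fine to ``vmap`` over because the tensor
# it mutates is one it created itself:

def double_in_place(x):
    out = x.clone()
    out.mul_(2)  # in-place PyTorch op on a locally-created tensor
    return out

print(vmap(double_in_place)(torch.randn(4, 3, device=device)).shape)  # torch.Size([4, 3])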

######################################################################
# Performance
# -----------
# Curious about performance numbers? Here's how the numbers look.

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

######################################################################
# There's a large speedup using vmap!
#
# In general, vectorization with vmap should be faster than running a function
# in a for-loop and competitive with manual batching. There are some exceptions
# though, like if we haven’t implemented the vmap rule for a particular
# operation or if the underlying kernels weren’t optimized for older hardware
# (GPUs). If you see any of these cases, please let us know by opening an issue
# on GitHub.
