From 2cb5217c75fae4abffd9f9be8a651ba4c984f68d Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Fri, 13 Jan 2023 10:00:28 -0800
Subject: [PATCH] Add torch.func tutorials for PyTorch 2.0

As the final step of integrating functorch into PyTorch, we have moved the
functorch APIs from under functorch.* to torch.func.* and made some
adjustments to them.

This PR moves the relevant functorch tutorials from the functorch docs
(https://pytorch.org/functorch/stable/) to pytorch/tutorials.

We moved four tutorials:
- Jacobians, Hessians, hvp, vhp, and more
- Model ensembling
- per-sample-gradients
- Neural Tangent Kernels.

We also rewrote the tutorials to use the torch.func.* APIs instead of the
functorch APIs, and excised mentions of functorch where appropriate.

Test Plan:
- view preview (is that possible for tutorials?)
---
 index.rst                                     |  32 ++
 intermediate_source/ensembling.py             | 171 +++++++++
 intermediate_source/jacobians_hessians.py     | 345 ++++++++++++++++++
 intermediate_source/neural_tangent_kernels.py | 244 +++++++++++++
 intermediate_source/per_sample_grads.py       | 221 +++++++++++
 5 files changed, 1013 insertions(+)
 create mode 100644 intermediate_source/ensembling.py
 create mode 100644 intermediate_source/jacobians_hessians.py
 create mode 100644 intermediate_source/neural_tangent_kernels.py
 create mode 100644 intermediate_source/per_sample_grads.py

diff --git a/index.rst b/index.rst
index 0e2f1eaeaa6..0a1a772da5f 100644
--- a/index.rst
+++ b/index.rst
@@ -417,6 +417,34 @@ What's new in PyTorch tutorials?
    :link: intermediate/forward_ad_usage.html
    :tags: Frontend-APIs

+.. customcarditem::
+   :header: Jacobians, Hessians, hvp, vhp, and more
+   :card_description: Learn how to compute advanced autodiff quantities using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/jacobians_hessians.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Model Ensembling
+   :card_description: Learn how to ensemble models using torch.vmap
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/ensembling.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Per-Sample-Gradients
+   :card_description: Learn how to compute per-sample-gradients using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/per_sample_grads.html
+   :tags: Frontend-APIs
+
+.. customcarditem::
+   :header: Neural Tangent Kernels
+   :card_description: Learn how to compute neural tangent kernels using torch.func
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/neural_tangent_kernels.html
+   :tags: Frontend-APIs
+
 .. Model Optimization

 .. customcarditem::
@@ -870,6 +898,10 @@ Additional Resources
    intermediate/memory_format_tutorial
    intermediate/forward_ad_usage
+   intermediate/jacobians_hessians
+   intermediate/ensembling
+   intermediate/per_sample_grads
+   intermediate/neural_tangent_kernels
    advanced/cpp_frontend
    advanced/torch-script-parallelism
    advanced/cpp_autograd

diff --git a/intermediate_source/ensembling.py b/intermediate_source/ensembling.py
new file mode 100644
index 00000000000..f44706d481e
--- /dev/null
+++ b/intermediate_source/ensembling.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+"""
+Model ensembling
+================
+
+This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.
+
+What is model ensembling?
+-------------------------
+Model ensembling combines the predictions from multiple models together.
+Traditionally this is done by running each model on some inputs separately +and then combining the predictions. However, if you're running models with +the same architecture, then it may be possible to combine them together +using ``torch.vmap``. ``vmap`` is a function transform that maps functions across +dimensions of the input tensors. One of its use cases is eliminating +for-loops and speeding them up through vectorization. + +Let's demonstrate how to do this using an ensemble of simple MLPs. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +torch.manual_seed(0) + +# Here's a simple MLP +class SimpleMLP(nn.Module): + def __init__(self): + super(SimpleMLP, self).__init__() + self.fc1 = nn.Linear(784, 128) + self.fc2 = nn.Linear(128, 128) + self.fc3 = nn.Linear(128, 10) + + def forward(self, x): + x = x.flatten(1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.relu(x) + x = self.fc3(x) + return x + +###################################################################### +# Let’s generate a batch of dummy data and pretend that we’re working with +# an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a +# minibatch of size 64. Furthermore, lets say we want to combine the predictions +# from 10 different models. + +device = 'cuda' +num_models = 10 + +data = torch.randn(100, 64, 1, 28, 28, device=device) +targets = torch.randint(10, (6400,), device=device) + +models = [SimpleMLP().to(device) for _ in range(num_models)] + +###################################################################### +# We have a couple of options for generating predictions. Maybe we want to +# give each model a different randomized minibatch of data. Alternatively, +# maybe we want to run the same minibatch of data through each model (e.g. +# if we were testing the effect of different model initializations). + +###################################################################### +# Option 1: different minibatch for each model + +minibatches = data[:num_models] +predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)] + +###################################################################### +# Option 2: Same minibatch + +minibatch = data[0] +predictions2 = [model(minibatch) for model in models] + +###################################################################### +# Using vmap to vectorize the ensemble +# ------------------------------------ +# +# Let's use vmap to speed up the for-loop. We must first prepare the models +# for use with vmap. +# +# First, let’s combine the states of the model together by stacking each +# parameter. For example, ``model[i].fc1.weight`` has shape ``[784, 128]``; we are +# going to stack the .fc1.weight of each of the 10 models to produce a big +# weight of shape ``[10, 784, 128]``. +# +# PyTorch offers the ``torch.func.stack_module_state`` convenience function to do +# this. +from torch.func import stack_module_state + +params, buffers = stack_module_state(models) + +###################################################################### +# Next, we need to define a function to vmap over. The function should, +# given parameters and buffers and inputs, run the model using those +# parameters, buffers, and inputs. We'll use ``torch.func.functional_call`` +# to help out: + +from torch.func import functional_call +import copy + +# Construct a "stateless" version of one of the models. It is "stateless" in +# the sense that the parameters are meta Tensors and do not have storage. 
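+# (An added note: moving the copy to the ``'meta'`` device below swaps every
+# parameter for a meta tensor, so the copy keeps the module structure without
+# allocating any memory; the real, stacked parameters are supplied at call
+# time through ``functional_call``.)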
+base_model = copy.deepcopy(models[0]) +base_model = base_model.to('meta') + +def fmodel(params, buffers, x): + return functional_call(base_model, (params, buffers), (x,)) + +###################################################################### +# Option 1: get predictions using a different minibatch for each model. +# +# By default, vmap maps a function across the first dimension of all inputs to +# the passed-in function. After using ``stack_module_state``, each of +# the params and buffers have an additional dimension of size 'num_models' at +# the front, and minibatches has a dimension of size 'num_models'. + +print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension + +assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models' + +from torch import vmap + +predictions1_vmap = vmap(fmodel)(params, buffers, minibatches) + +# verify the vmap predictions match the +assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5) + +###################################################################### +# Option 2: get predictions using the same minibatch of data. +# +# vmap has an in_dims arg that specifies which dimensions to map over. +# By using ``None``, we tell vmap we want the same minibatch to apply for all of +# the 10 models. + +predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) + +assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5) + +###################################################################### +# A quick note: there are limitations around what types of functions can be +# transformed by vmap. The best functions to transform are ones that are pure +# functions: a function where the outputs are only determined by the inputs +# that have no side effects (e.g. mutation). vmap is unable to handle mutation +# of arbitrary Python data structures, but it is able to handle many in-place +# PyTorch operations. + +###################################################################### +# Performance +# ----------- +# Curious about performance numbers? Here's how the numbers look. + +from torch.utils.benchmark import Timer +without_vmap = Timer( + stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]", + globals=globals()) +with_vmap = Timer( + stmt="vmap(fmodel)(params, buffers, minibatches)", + globals=globals()) +print(f'Predictions without vmap {without_vmap.timeit(100)}') +print(f'Predictions with vmap {with_vmap.timeit(100)}') + +###################################################################### +# There's a large speedup using vmap! +# +# In general, vectorization with vmap should be faster than running a function +# in a for-loop and competitive with manual batching. There are some exceptions +# though, like if we haven’t implemented the vmap rule for a particular +# operation or if the underlying kernels weren’t optimized for older hardware +# (GPUs). If you see any of these cases, please let us know by opening an issue +# on GitHub. 
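+
+######################################################################
+# As a closing sketch (an addition to the walkthrough above, not part of the
+# benchmark), the stacked predictions can be reduced to a single ensemble
+# prediction, for example by averaging the per-model probabilities from
+# Option 2:
+
+ensemble_probs = F.softmax(predictions2_vmap, dim=-1).mean(dim=0)  # average over the model dimension
+ensemble_classes = ensemble_probs.argmax(dim=-1)  # one predicted class per example in the minibatch
+print(ensemble_classes.shape)  # torch.Size([64])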
diff --git a/intermediate_source/jacobians_hessians.py b/intermediate_source/jacobians_hessians.py new file mode 100644 index 00000000000..071321d6bba --- /dev/null +++ b/intermediate_source/jacobians_hessians.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- +""" +Jacobians, Hessians, hvp, vhp, and more: composing function transforms +====================================================================== + +Computing jacobians or hessians are useful in a number of non-traditional +deep learning models. It is difficult (or annoying) to compute these quantities +efficiently using PyTorch's regular autodiff APIs +(``Tensor.backward()``, ``torch.autograd.grad``). PyTorch's +`JAX-inspired `_ +`function transforms API `_ +provides ways of computing various higher-order autodiff quantities +efficiently. + +Computing the Jacobian +---------------------- +""" + +import torch +import torch.nn.functional as F +from functools import partial +_ = torch.manual_seed(0) + +###################################################################### +# Let's start with a function that we'd like to compute the jacobian of. +# This is a simple linear function with non-linear activation. + +def predict(weight, bias, x): + return F.linear(x, weight, bias).tanh() + +###################################################################### +# Let's add some dummy data: a weight, a bias, and a feature vector x. + +D = 16 +weight = torch.randn(D, D) +bias = torch.randn(D) +x = torch.randn(D) # feature vector + +###################################################################### +# Let's think of ``predict`` as a function that maps the input ``x`` from :math:`R^D \to R^D`. +# PyTorch Autograd computes vector-Jacobian products. In order to compute the full +# Jacobian of this :math:`R^D \to R^D` function, we would have to compute it row-by-row +# by using a different unit vector each time. + +def compute_jac(xp): + jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] + for vec in unit_vectors] + return torch.stack(jacobian_rows) + +xp = x.clone().requires_grad_() +unit_vectors = torch.eye(D) + +jacobian = compute_jac(xp) + +print(jacobian.shape) +print(jacobian[0]) # show first row + +###################################################################### +# Instead of computing the jacobian row-by-row, we can use PyTorch's +# ``torch.vmap`` function transform to get rid of the for-loop and vectorize the +# computation. We can’t directly apply vmap to ``torch.autograd.grad``; +# instead, PyTorch provides a ``torch.func.vjp`` transform that composes with +# ``torch.vmap``: + +from torch.func import vmap, vjp + +_, vjp_fn = vjp(partial(predict, weight, bias), x) + +ft_jacobian, = vmap(vjp_fn)(unit_vectors) + +# let's confirm both methods compute the same result +assert torch.allclose(ft_jacobian, jacobian) + +###################################################################### +# In a later tutorial a composition of reverse-mode AD and vmap will give us +# per-sample-gradients. +# In this tutorial, composing reverse-mode AD and vmap gives us Jacobian +# computation! +# Various compositions of vmap and autodiff transforms can give us different +# interesting quantities. +# +# PyTorch provides ``torch.func.jacrev`` as a convenience function that performs +# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums +# argument that says which argument we would like to compute Jacobians with +# respect to. 
+ +from torch.func import jacrev + +ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) + +# confirm +assert torch.allclose(ft_jacobian, jacobian) + +###################################################################### +# Let's compare the performance of the two ways to compute the jacobian. +# The function transform version is much faster (and becomes even faster the +# more outputs there are). +# +# In general, we expect that vectorization via vmap can help eliminate overhead +# and give better utilization of your hardware. +# +# vmap does this magic by pushing the outer loop down into the function's +# primitive operations in order to obtain better performance. +# +# Let's make a quick function to evaluate performance and deal with +# microseconds and milliseconds measurements: + +def get_perf(first, first_descriptor, second, second_descriptor): + """takes torch.benchmark objects and compares delta of second vs first.""" + faster = second.times[0] + slower = first.times[0] + gain = (slower-faster)/slower + if gain < 0: gain *=-1 + final_gain = gain*100 + print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ") + +###################################################################### +# And then run the performance comparison: + +from torch.utils.benchmark import Timer + +without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) +with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) + +no_vmap_timer = without_vmap.timeit(500) +with_vmap_timer = with_vmap.timeit(500) + +print(no_vmap_timer) +print(with_vmap_timer) + +###################################################################### +# Let's do a relative performance comparison of the above with our get_perf function: + +get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap") + +###################################################################### +# Furthemore, it’s pretty easy to flip the problem around and say we want to +# compute Jacobians of the parameters to our model (weight, bias) instead of the input + +# note the change in input via argnums params of 0,1 to map to weight and bias +ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x) + +###################################################################### +# reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd) +# -------------------------------------------------------------------- +# +# We offer two APIs to compute jacobians: ``jacrev`` and ``jacfwd``: +# +# - jacrev uses reverse-mode AD. As you saw above it is a composition of our +# vjp and vmap transforms. +# - jacfwd uses forward-mode AD. It is implemented as a composition of our +# jvp and vmap transforms. +# +# jacfwd and jacrev can be substituted for each other but they have different +# performance characteristics. +# +# As a general rule of thumb, if you’re computing the jacobian of an :math:`R^N \to R^M` +# function, and there are many more outputs than inputs (i.e. :math:`M > N`) then +# jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, +# but a non-rigorous argument for this follows: +# +# In reverse-mode AD, we are computing the jacobian row-by-row, while in +# forward-mode AD (which computes Jacobian-vector products), we are computing +# it column-by-column. The Jacobian matrix has M rows and N columns, so if it +# is taller or wider one way we may prefer the method that deals with fewer +# rows or columns. 
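+
+######################################################################
+# As a quick sanity check (a small sketch added here, reusing the ``D = 16``
+# tensors defined above), both transforms compute the same Jacobian up to
+# numerical error; only their performance characteristics differ:
+
+from torch.func import jacfwd  # jacrev was imported earlier
+
+assert torch.allclose(jacfwd(predict, argnums=2)(weight, bias, x),
+                      jacrev(predict, argnums=2)(weight, bias, x), atol=1e-5)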
+
+from torch.func import jacrev, jacfwd
+
+######################################################################
+# First, let's benchmark with more outputs than inputs:
+
+Din = 32
+Dout = 2048
+weight = torch.randn(Dout, Din)
+
+bias = torch.randn(Dout)
+x = torch.randn(Din)
+
+# remember the general rule about taller vs wider... here we have a taller matrix:
+print(weight.shape)
+
+using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
+using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
+
+jacfwd_timing = using_fwd.timeit(500)
+jacrev_timing = using_bwd.timeit(500)
+
+print(f'jacfwd time: {jacfwd_timing}')
+print(f'jacrev time: {jacrev_timing}')
+
+######################################################################
+# and then do a relative benchmark; jacfwd wins here because the Jacobian is
+# tall (many more outputs than inputs):

+get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
+
+######################################################################
+# and now the reverse - more inputs (N) than outputs (M):
+
+Din = 2048
+Dout = 32
+weight = torch.randn(Dout, Din)
+bias = torch.randn(Dout)
+x = torch.randn(Din)
+
+using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
+using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
+
+jacfwd_timing = using_fwd.timeit(500)
+jacrev_timing = using_bwd.timeit(500)
+
+print(f'jacfwd time: {jacfwd_timing}')
+print(f'jacrev time: {jacrev_timing}')
+
+######################################################################
+# and a relative perf comparison; this time jacrev wins because the Jacobian
+# is wide:
+
+get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")
+
+######################################################################
+# Hessian computation with torch.func.hessian
+# --------------------------------------------
+# We offer a convenience API to compute hessians: ``torch.func.hessian``.
+# Hessians are the jacobian of the jacobian (or the partial derivative of
+# the partial derivative, aka second order).
+#
+# This suggests that one can just compose ``torch.func``'s jacobian transforms to
+# compute the Hessian.
+# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.
+#
+# Note: to boost performance, depending on your model, you may also want to
+# use ``jacfwd(jacfwd(f))`` or ``jacrev(jacrev(f))`` instead to compute hessians,
+# leveraging the rule of thumb above regarding wider vs taller matrices.
+
+from torch.func import hessian
+
+# let's reduce the size in order not to blow out colab. Hessians require
+# significant memory:
+Din = 512
+Dout = 32
+weight = torch.randn(Dout, Din)
+bias = torch.randn(Dout)
+x = torch.randn(Din)
+
+hess_api = hessian(predict, argnums=2)(weight, bias, x)
+hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
+hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
+
+######################################################################
+# Let's verify we have the same result regardless of whether we use the
+# hessian API or jacfwd(jacfwd()):
+
+torch.allclose(hess_api, hess_fwdfwd)
+
+######################################################################
+# Batch Jacobian and Batch Hessian
+# --------------------------------
+# In the above examples we've been operating with a single feature vector.
+# In some cases you might want to take the Jacobian of a batch of outputs
+# with respect to a batch of inputs.
That is, given a batch of inputs of +# shape ``(B, N)`` and a function that goes from :math:`R^N \to R^M`, we would like +# a Jacobian of shape ``(B, M, N)``. +# +# The easiest way to do this is to use vmap: + +batch_size = 64 +Din = 31 +Dout = 33 + +weight = torch.randn(Dout, Din) +print(f"weight shape = {weight.shape}") + +bias = torch.randn(Dout) + +x = torch.randn(batch_size, Din) + +compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0)) +batch_jacobian0 = compute_batch_jacobian(weight, bias, x) + +####################################################################### +# If you have a function that goes from (B, N) -> (B, M) instead and are +# certain that each input produces an independent output, then it's also +# sometimes possible to do this without using vmap by summing the outputs +# and then computing the Jacobian of that function: + +def predict_with_output_summed(weight, bias, x): + return predict(weight, bias, x).sum(0) + +batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0) +assert torch.allclose(batch_jacobian0, batch_jacobian1) + +####################################################################### +# If you instead have a function that goes from :math:`R^N \to R^M` but inputs that +# are batched, you compose vmap with jacrev to compute batched jacobians: +# +# Finally, batch hessians can be computed similarly. It's easiest to think +# about them by using vmap to batch over hessian computation, but in some +# cases the sum trick also works. + +compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0)) + +batch_hess = compute_batch_hessian(weight, bias, x) +batch_hess.shape + +####################################################################### +# Computing Hessian-vector products +# --------------------------------- +# The naive way to compute a Hessian-vector product (hvp) is to materialize +# the full Hessian and perform a dot-product with a vector. We can do better: +# it turns out we don't need to materialize the full Hessian to do this. We'll +# go through two (of many) different strategies to compute Hessian-vector products: +# - composing reverse-mode AD with reverse-mode AD +# - composing reverse-mode AD with forward-mode AD +# +# Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode +# with reverse-mode) is generally the more memory efficient way to compute a +# hvp because forward-mode AD doesn't need to construct an Autograd graph and +# save intermediates for backward: + +from torch.func import jvp, grad, vjp + +def hvp(f, primals, tangents): + return jvp(grad(f), primals, tangents)[1] + +####################################################################### +# Here's some sample usage. 
+ +def f(x): + return x.sin().sum() + +x = torch.randn(2048) +tangent = torch.randn(2048) + +result = hvp(f, (x,), (tangent,)) + +####################################################################### +# If PyTorch forward-AD does not have coverage for your operations, then we can +# instead compose reverse-mode AD with reverse-mode AD: + +def hvp_revrev(f, primals, tangents): + _, vjp_fn = vjp(grad(f), *primals) + return vjp_fn(*tangents) + +result_hvp_revrev = hvp_revrev(f, (x,), (tangent,)) +assert torch.allclose(result, result_hvp_revrev[0]) diff --git a/intermediate_source/neural_tangent_kernels.py b/intermediate_source/neural_tangent_kernels.py new file mode 100644 index 00000000000..37d804b883b --- /dev/null +++ b/intermediate_source/neural_tangent_kernels.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +""" +Neural Tangent Kernels +====================== + +The neural tangent kernel (NTK) is a kernel that describes +`how a neural network evolves during training `_. +There has been a lot of research around it `in recent years `_. +This tutorial, inspired by the implementation of `NTKs in JAX `_ +(see `Fast Finite Width Neural Tangent Kernel `_ for details), +demonstrates how to easily compute this quantity using ``torch.func``, +composable function transforms for PyTorch. + +Setup +----- + +First, some setup. Let's define a simple CNN that we wish to compute the NTK of. +""" + +import torch +import torch.nn as nn +from torch.func import functional_call, vmap, vjp, jvp, jacrev +device = 'cuda' + +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(3, 32, (3, 3)) + self.conv2 = nn.Conv2d(32, 32, (3, 3)) + self.conv3 = nn.Conv2d(32, 32, (3, 3)) + self.fc = nn.Linear(21632, 10) + + def forward(self, x): + x = self.conv1(x) + x = x.relu() + x = self.conv2(x) + x = x.relu() + x = self.conv3(x) + x = x.flatten(1) + x = self.fc(x) + return x + +###################################################################### +# And let's generate some random data + +x_train = torch.randn(20, 3, 32, 32, device=device) +x_test = torch.randn(5, 3, 32, 32, device=device) + +###################################################################### +# Create a function version of the model +# -------------------------------------- +# +# ``torch.func`` transforms operate on functions. In particular, to compute the NTK, +# we will need a function that accepts the parameters of the model and a single +# input (as opposed to a batch of inputs!) and returns a single output. +# +# We'll use ``torch.func.functional_call``, which allows us to call an nn.Module +# using different parameters/buffers, to help accomplish the first step. +# +# Keep in mind that the model was originally written to accept a batch of input +# data points. In our CNN example, there are no inter-batch operations. That +# is, each data point in the batch is independent of other data points. With +# this assumption in mind, we can easily generate a function that evaluates the +# model on a single data point: + + +net = CNN().to(device) + +# Detaching the parameters because we won't be calling Tensor.backward(). +params = {k: v.detach() for k, v in net.named_parameters()} + +def fnet_single(params, x): + return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0) + +###################################################################### +# Compute the NTK: method 1 (Jacobian contraction) +# ------------------------------------------------ +# We're ready to compute the empirical NTK. 
The empirical NTK for two data +# points :math:`x_1` and :math:`x_2` is defined as the matrix product between the Jacobian +# of the model evaluated at :math:`x_1` and the Jacobian of the model evaluated at +# :math:`x_2`: +# +# .. math:: +# +# J_{net}(x_1) J_{net}^T(x_2) +# +# In the batched case where :math:`x_1` is a batch of data points and :math:`x_2` is a +# batch of data points, then we want the matrix product between the Jacobians +# of all combinations of data points from :math:`x_1` and :math:`x_2`. +# +# The first method consists of doing just that - computing the two Jacobians, +# and contracting them. Here's how to compute the NTK in the batched case: + +def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2): + # Compute J(x1) + jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1) + jac1 = jac1.values() + jac1 = [j.flatten(2) for j in jac1] + + # Compute J(x2) + jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2) + jac2 = jac2.values() + jac2 = [j.flatten(2) for j in jac2] + + # Compute J(x1) @ J(x2).T + result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]) + result = result.sum(0) + return result + +result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test) +print(result.shape) + +###################################################################### +# In some cases, you may only want the diagonal or the trace of this quantity, +# especially if you know beforehand that the network architecture results in an +# NTK where the non-diagonal elements can be approximated by zero. It's easy to +# adjust the above function to do that: + +def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'): + # Compute J(x1) + jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1) + jac1 = jac1.values() + jac1 = [j.flatten(2) for j in jac1] + + # Compute J(x2) + jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2) + jac2 = jac2.values() + jac2 = [j.flatten(2) for j in jac2] + + # Compute J(x1) @ J(x2).T + einsum_expr = None + if compute == 'full': + einsum_expr = 'Naf,Mbf->NMab' + elif compute == 'trace': + einsum_expr = 'Naf,Maf->NM' + elif compute == 'diagonal': + einsum_expr = 'Naf,Maf->NMa' + else: + assert False + + result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]) + result = result.sum(0) + return result + +result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace') +print(result.shape) + +###################################################################### +# The asymptotic time complexity of this method is :math:`N O [FP]` (time to +# compute the Jacobians) + :math:`N^2 O^2 P` (time to contract the Jacobians), +# where :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O` +# is the model's output size, :math:`P` is the total number of parameters, and +# :math:`[FP]` is the cost of a single forward pass through the model. See +# section 3.2 in +# `Fast Finite Width Neural Tangent Kernel `_ +# for details. +# +# Compute the NTK: method 2 (NTK-vector products) +# ----------------------------------------------- +# +# The next method we will discuss is a way to compute the NTK using NTK-vector +# products. +# +# This method reformulates NTK as a stack of NTK-vector products applied to +# columns of an identity matrix :math:`I_O` of size :math:`O\times O` +# (where :math:`O` is the output size of the model): +# +# .. 
math:: +# +# J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O}, +# +# where :math:`e_o\in \mathbb{R}^O` are column vectors of the identity matrix +# :math:`I_O`. +# +# - Let :math:`\textrm{vjp}_o = J_{net}^T(x_2) e_o`. We can use +# a vector-Jacobian product to compute this. +# - Now, consider :math:`J_{net}(x_1) \textrm{vjp}_o`. This is a +# Jacobian-vector product! +# - Finally, we can run the above computation in parallel over all +# columns :math:`e_o` of :math:`I_O` using ``vmap``. +# +# This suggests that we can use a combination of reverse-mode AD (to compute +# the vector-Jacobian product) and forward-mode AD (to compute the +# Jacobian-vector product) to compute the NTK. +# +# Let's code that up: + +def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'): + def get_ntk(x1, x2): + def func_x1(params): + return func(params, x1) + + def func_x2(params): + return func(params, x2) + + output, vjp_fn = vjp(func_x1, params) + + def get_ntk_slice(vec): + # This computes vec @ J(x2).T + # `vec` is some unit vector (a single slice of the Identity matrix) + vjps = vjp_fn(vec) + # This computes J(X1) @ vjps + _, jvps = jvp(func_x2, (params,), vjps) + return jvps + + # Here's our identity matrix + basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1) + return vmap(get_ntk_slice)(basis) + + # get_ntk(x1, x2) computes the NTK for a single data point x1, x2 + # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched, + # we actually wish to compute the NTK between every pair of data points + # between {x1} and {x2}. That's what the vmaps here do. + result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2) + + if compute == 'full': + return result + if compute == 'trace': + return torch.einsum('NMKK->NM', result) + if compute == 'diagonal': + return torch.einsum('NMKK->NMK', result) + +result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train) +result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train) +assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5) + +###################################################################### +# Our code for ``empirical_ntk_ntk_vps`` looks like a direct translation from +# the math above! This showcases the power of function transforms: good luck +# trying to write an efficient version of the above by only using +# ``torch.autograd.grad``. +# +# The asymptotic time complexity of this method is :math:`N^2 O [FP]`, where +# :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O` is the +# model's output size, and :math:`[FP]` is the cost of a single forward pass +# through the model. Hence this method performs more forward passes through the +# network than method 1, Jacobian contraction (:math:`N^2 O` instead of +# :math:`N O`), but avoids the contraction cost altogether (no :math:`N^2 O^2 P` +# term, where :math:`P` is the total number of model's parameters). Therefore, +# this method is preferable when :math:`O P` is large relative to :math:`[FP]`, +# such as fully-connected (not convolutional) models with many outputs :math:`O`. +# Memory-wise, both methods should be comparable. See section 3.3 in +# `Fast Finite Width Neural Tangent Kernel `_ +# for details. 
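+
+######################################################################
+# As a final sketch (an addition to the tutorial above), the same helpers can
+# be used to build the train-train kernel matrix. With ``compute='trace'`` the
+# result is a single ``N x N`` matrix, which should be symmetric when ``x1``
+# and ``x2`` are the same batch:
+
+ntk_train_train = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_train, 'trace')
+print(ntk_train_train.shape)  # torch.Size([20, 20])
+assert torch.allclose(ntk_train_train, ntk_train_train.T, atol=1e-4)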
diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py
new file mode 100644
index 00000000000..0dbdf8c94f1
--- /dev/null
+++ b/intermediate_source/per_sample_grads.py
@@ -0,0 +1,221 @@
+# -*- coding: utf-8 -*-
+"""
+Per-sample-gradients
+====================
+
+What is it?
+-----------
+
+Per-sample-gradient computation is computing the gradient for each and every
+sample in a batch of data. It is a useful quantity in differential privacy,
+meta-learning, and optimization research.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+torch.manual_seed(0)
+
+# Here's a simple CNN and loss function:
+
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super(SimpleCNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+def loss_fn(predictions, targets):
+    return F.nll_loss(predictions, targets)
+
+
+######################################################################
+# Let's generate a batch of dummy data and pretend that we're working with an MNIST dataset.
+# The dummy images are 28 by 28 and we use a minibatch of size 64.
+
+device = 'cuda'
+
+num_models = 10
+batch_size = 64
+data = torch.randn(batch_size, 1, 28, 28, device=device)
+
+targets = torch.randint(10, (64,), device=device)
+
+######################################################################
+# In regular model training, one would forward the minibatch through the model,
+# and then call .backward() to compute gradients. This would generate an
+# 'average' gradient of the entire mini-batch:
+
+model = SimpleCNN().to(device=device)
+predictions = model(data)  # move the entire mini-batch through the model
+
+loss = loss_fn(predictions, targets)
+loss.backward()  # back propagate the 'average' gradient of this mini-batch
+
+######################################################################
+# In contrast to the above approach, per-sample-gradient computation is
+# equivalent to:
+#
+# - for each individual sample of the data, perform a forward and a backward
+#   pass to get an individual (per-sample) gradient.
+
+def compute_grad(sample, target):
+    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
+    target = target.unsqueeze(0)
+
+    prediction = model(sample)
+    loss = loss_fn(prediction, target)
+
+    return torch.autograd.grad(loss, list(model.parameters()))
+
+
+def compute_sample_grads(data, targets):
+    """ manually process each sample with per sample gradient """
+    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
+    sample_grads = zip(*sample_grads)
+    sample_grads = [torch.stack(shards) for shards in sample_grads]
+    return sample_grads
+
+per_sample_grads = compute_sample_grads(data, targets)
+
+######################################################################
+# ``per_sample_grads[0]`` is the per-sample-grad for ``model.conv1.weight``.
+# ``model.conv1.weight.shape`` is ``[32, 1, 3, 3]``; notice how there is one
+# gradient, per sample, in the batch for a total of 64.
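+#
+# (Added note: we expect the batch dimension of 64 to be prepended to
+# ``conv1.weight``'s ``[32, 1, 3, 3]`` shape, i.e. ``[64, 32, 1, 3, 3]``.)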
+
+print(per_sample_grads[0].shape)
+
+######################################################################
+# Per-sample-grads, *the efficient way*, using function transforms
+# -----------------------------------------------------------------
+# We can compute per-sample-gradients efficiently by using function transforms.
+#
+# The ``torch.func`` function transform API transforms over functions.
+# Our strategy is to define a function that computes the loss and then apply
+# transforms to construct a function that computes per-sample-gradients.
+#
+# We'll use the ``torch.func.functional_call`` function to treat an nn.Module
+# like a function.
+#
+# First, let's extract the state from ``model`` into two dictionaries,
+# parameters and buffers. We'll be detaching them because we won't use
+# regular PyTorch autograd (e.g. Tensor.backward(), torch.autograd.grad).
+
+from torch.func import functional_call, vmap, grad
+
+params = {k: v.detach() for k, v in model.named_parameters()}
+buffers = {k: v.detach() for k, v in model.named_buffers()}
+
+######################################################################
+# Next, let's define a function to compute the loss of the model given a
+# single input rather than a batch of inputs. It is important that this
+# function accepts the parameters, the input, and the target, because we will
+# be transforming over them.
+#
+# Note - because the model was originally written to handle batches, we'll
+# use ``torch.unsqueeze`` to add a batch dimension.
+
+def compute_loss(params, buffers, sample, target):
+    batch = sample.unsqueeze(0)
+    targets = target.unsqueeze(0)
+
+    predictions = functional_call(model, (params, buffers), (batch,))
+    loss = loss_fn(predictions, targets)
+    return loss
+
+######################################################################
+# Now, let's use the ``grad`` transform to create a new function that computes
+# the gradient with respect to the first argument of ``compute_loss``
+# (i.e. the params).
+
+ft_compute_grad = grad(compute_loss)
+
+######################################################################
+# The ``ft_compute_grad`` function computes the gradient for a single
+# (sample, target) pair. We can use vmap to get it to compute the gradient
+# over an entire batch of samples and targets. Note that
+# ``in_dims=(None, None, 0, 0)`` because we wish to map ``ft_compute_grad`` over
+# the 0th dimension of the data and targets, and use the same params and
+# buffers for each.
+
+ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
+
+######################################################################
+# Finally, let's use our transformed function to compute per-sample-gradients:
+
+ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
+
+######################################################################
+# We can double check that the results using ``grad`` and ``vmap`` match the
+# results of hand processing each one individually:
+
+for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
+
+######################################################################
+# A quick note: there are limitations around what types of functions can be
+# transformed by vmap. The best functions to transform are ones that are pure
+# functions: a function where the outputs are only determined by the inputs,
+# and that have no side effects (e.g. mutation). vmap is unable to handle
+# mutation of arbitrary Python data structures, but it is able to handle many
+# in-place PyTorch operations.
+#
+# Performance comparison
+# ----------------------
+#
+# Curious about how the performance of vmap compares?
+#
+# Currently the best results are obtained on newer GPUs such as the A100
+# (Ampere) where we've seen up to 25x speedups on this example, but here are
+# some results on our build machines:

+def get_perf(first, first_descriptor, second, second_descriptor):
+    """takes torch.benchmark objects and compares delta of second vs first."""
+    second_res = second.times[0]
+    first_res = first.times[0]
+
+    gain = (first_res-second_res)/first_res
+    if gain < 0: gain *= -1
+    final_gain = gain*100
+
+    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
+
+from torch.utils.benchmark import Timer
+
+without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
+with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)", globals=globals())
+no_vmap_timing = without_vmap.timeit(100)
+with_vmap_timing = with_vmap.timeit(100)
+
+print(f'Per-sample-grads without vmap {no_vmap_timing}')
+print(f'Per-sample-grads with vmap {with_vmap_timing}')
+
+get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
+
+######################################################################
+# There are other optimized solutions (like in https://github.com/pytorch/opacus)
+# to computing per-sample-gradients in PyTorch that also perform better than
+# the naive method. But it's cool that composing ``vmap`` and ``grad`` gives us a
+# nice speedup.
+#
+# In general, vectorization with vmap should be faster than running a function
+# in a for-loop and competitive with manual batching. There are some exceptions
+# though, like if we haven't implemented the vmap rule for a particular
+# operation or if the underlying kernels weren't optimized for older hardware
+# (GPUs). If you see any of these cases, please let us know by opening an issue
+# on GitHub.
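+
+######################################################################
+# As a final sketch (an addition, not part of the benchmark above): per-sample
+# gradients are often reduced to per-sample gradient *norms*, e.g. for the
+# clipping step in differential privacy. Each entry of ``ft_per_sample_grads``
+# has a leading batch dimension, so the norms can be computed with a flatten
+# and a sum per parameter:
+
+squared_norms = [g.flatten(1).pow(2).sum(dim=1) for g in ft_per_sample_grads.values()]
+per_sample_norms = torch.stack(squared_norms).sum(dim=0).sqrt()
+print(per_sample_norms.shape)  # torch.Size([64]) -- one gradient norm per sample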