# -*- coding: utf-8 -*-
"""
Forward-mode Auto Differentiation
=================================

This tutorial demonstrates how to compute directional derivatives
(or, equivalently, Jacobian-vector products) with forward-mode AD.

"""

######################################################################
# Basic Usage
# --------------------------------------------------------------------
# Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
# alongside the forward pass. We can compute a directional derivative
# with forward-mode AD by performing the forward pass as we usually do,
# except that, prior to calling the function, we first associate our
# input with another tensor representing the direction of the directional
# derivative, or equivalently, the ``v`` in a Jacobian-vector product.
# We also call this "direction" tensor a tangent tensor.
#
# As the forward pass is performed, if any input tensors have associated
# tangents, extra computation is performed to propagate this "sensitivity"
# of the function. Pairing a primal value with its tangent in this way is
# analogous to working with dual numbers [0].
#
# [0] https://en.wikipedia.org/wiki/Dual_number

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed inside a ``dual_level``
# context. All dual tensors created in a context will have their tangents
# destroyed upon exit. This is to ensure that if the output or intermediate
# results of this computation are reused in a future forward AD computation,
# their tangents (which are associated with this computation) won't be
# confused with tangents from the later computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is important to take note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert dual_input._base is primal
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # Any tensors involved in the computation that do not have an associated
    # tangent are automatically considered to have a zero-filled tangent.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as its attributes.
    jvp = fwAD.unpack_dual(dual_output).tangent

# Tangents are destroyed once we exit the ``dual_level`` context, so unpacking
# the dual outside of it yields a ``None`` tangent; only the primal remains.
assert fwAD.unpack_dual(dual_output).tangent is None
output = fwAD.unpack_dual(dual_output).primal

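# A quick sanity check (a sketch based on the function above): for
# ``fn(x, y) = x**2 + y**2``, with the tangent attached only to ``x`` (the
# tangent of ``y`` is implicitly zero), the Jacobian-vector product is
# ``2 * primal * tangent``.
assert torch.allclose(jvp, 2 * primal * tangent)
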
######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module``s with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass.

import torch.nn as nn

model = nn.Linear(10, 10)
input = torch.randn(64, 10)

# Collect the parameters and create a tangent for each of them up front,
# since the parameter attributes of the model are replaced below.
params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        # ``make_dual`` returns a plain tensor (a view of the primal), not an
        # ``nn.Parameter``, so the registered parameter must be removed before
        # the dual tensor can be assigned in its place. We also detach the
        # parameter to avoid the extra overhead of creating a backward graph.
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p.detach(), tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

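# A quick sanity check (a sketch, assuming the ``tangents`` built above): for a
# linear layer ``y = x @ W.T + b`` whose parameters carry tangents ``(dW, db)``
# while the input carries none, the JVP is ``x @ dW.T + db``.
assert torch.allclose(jvp, input @ tangents["weight"].T + tangents["bias"])
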
######################################################################
# Using the stateless Module API
# --------------------------------------------------------------------
# Another way to use ``nn.Module``s with forward AD is to utilize
# the stateless API:

from torch.nn.utils._stateless import functional_call

# The previous section replaced the parameters of ``model`` with plain dual
# tensors, so we create a fresh model with properly registered parameters here.
model = nn.Linear(10, 10)

dual_params = {}
with fwAD.dual_level():
    for name, p in model.named_parameters():
        dual_params[name] = fwAD.make_dual(p, torch.randn_like(p))
    out = functional_call(model, dual_params, input)
    jvp = fwAD.unpack_dual(out).tangent
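    # A quick sanity check (a sketch, assuming the ``dual_params`` built above):
    # as in the previous section, the JVP of a linear layer is ``x @ dW.T + db``.
    dW = fwAD.unpack_dual(dual_params["weight"]).tangent
    db = fwAD.unpack_dual(dual_params["bias"]).tangent
    assert torch.allclose(jvp, input @ dW.T + db)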


######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom autograd Functions can support forward-mode AD by implementing the
# ``jvp()`` static method, which receives the tangents of the inputs and
# returns the tangents of the outputs.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo * 2

    @staticmethod
    def jvp(ctx, gI):
        # The tangent of ``foo * 2`` is simply the input tangent scaled by 2.
        return gI * 2
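
# A minimal usage sketch: apply ``Fn`` to a dual tensor inside a ``dual_level``
# context. Since ``forward`` multiplies its input by 2, the resulting JVP
# should equal ``tangent * 2``.
fn2 = Fn.apply

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn2(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent
    assert torch.allclose(jvp, tangent * 2)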