# -*- coding: utf-8 -*-
"""
Forward-mode Automatic Differentiation
======================================

This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or equivalently, Jacobian-vector products).

Basic Usage
--------------------------------------------------------------------
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
alongside the forward pass. We can use forward-mode AD to compute a
directional derivative by performing the forward pass as before,
except we first associate our input with another tensor representing
the direction of the directional derivative (or equivalently, the ``v``
in a Jacobian-vector product). When an input, which we call the "primal",
is associated with a "direction" tensor, which we call the "tangent", the
resulting new tensor object is called a "dual tensor" because of its
connection to dual numbers [0].

As the forward pass is performed, if any input tensors are dual tensors,
extra computation is performed to propagate this "sensitivity" of the
function.

"""

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed inside a ``dual_level``
# context. All dual tensors created in such a context will have their
# tangents destroyed upon exit. This is to ensure that if the output or
# intermediate results of this computation are reused in a future
# forward AD computation, their tangents (which are associated with this
# computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the tangent is copied, we pass in a
    # tangent with a layout different from that of the primal.
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

assert fwAD.unpack_dual(dual_output).tangent is None
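
# For this simple function we can also verify the result analytically:
# since ``plain_tensor`` carries a zero tangent, the JVP of ``fn`` along
# ``tangent`` is just ``2 * primal * tangent``.
assert torch.allclose(jvp, 2 * primal * tangent)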

######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass. At the
# time of writing, it is not possible to create dual tensor
# ``nn.Parameter`` instances. As a workaround, one must set the dual
# tensor as a non-parameter attribute of the module.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

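# For ``nn.Linear``, the Jacobian-vector product with respect to the
# parameters can be written out analytically: with a weight tangent ``tW``
# and a bias tangent ``tb`` it is ``input @ tW.T + tb``, so we can
# sanity-check the result computed above.
expected_jvp = input @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp, expected_jvp)
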
######################################################################
# Using the stateless Module API (experimental)
# --------------------------------------------------------------------
# Another way to use ``nn.Module`` with forward AD is to utilize
# the stateless API. NB: At the time of writing, the stateless API is still
# experimental and may be subject to change.

from torch.nn.utils._stateless import functional_call

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)
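
# Unlike the ``delattr``/``setattr`` workaround above, the stateless call
# only swaps the parameters in for the duration of the call, so the fresh
# module's registered parameters should still be intact afterwards.
assert {name for name, _ in model.named_parameters()} == {"weight", "bias"}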

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# that supports forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ctx can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        # If the tensor stored in ctx will not also be used in the backward pass,
        # one can manually free it using ``del``
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

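# Because the derivative of ``exp`` is ``exp`` itself, the JVP produced by
# our custom Function can be checked against the analytical result
# ``torch.exp(primal) * tangent``.
assert torch.allclose(jvp, torch.exp(primal) * tangent)
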
# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell gradcheck
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)

######################################################################
# [0] https://en.wikipedia.org/wiki/Dual_number