Commit 796013c

Add forward AD tutorial (#1746)
* Add forward AD tutorial
* Update
* Formatting
* Address comments
* Fix quoting
1 parent c2115df commit 796013c

File tree

1 file changed: +170 -0 lines changed

forward_ad_usage.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
"""
Forward-mode Automatic Differentiation
======================================

This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or, equivalently, Jacobian-vector products).

Basic Usage
--------------------------------------------------------------------
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
alongside the forward pass. We can use forward-mode AD to compute a
directional derivative by performing the forward pass as before,
except we first associate our input with another tensor representing
the direction of the directional derivative (or, equivalently, the ``v``
in a Jacobian-vector product). When an input, which we call the "primal",
is associated with a "direction" tensor, which we call the "tangent", the
resultant new tensor object is called a "dual tensor" for its connection
to dual numbers [0]. For example, evaluating ``f(x) = x ** 2`` at the dual
number ``a + b * eps`` (where ``eps ** 2 = 0``) gives
``a ** 2 + 2 * a * b * eps``; the coefficient of ``eps`` is exactly the
directional derivative of ``f`` at ``a`` in the direction ``b``.

As the forward pass is performed, if any input tensors are dual tensors,
extra computation is performed to propagate this "sensitivity" of the
function.

"""

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

# Once we exit the ``dual_level`` context, the tangents of the dual tensors
# created within it are destroyed.
assert fwAD.unpack_dual(dual_output).tangent is None

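# As an illustrative check (an addition to the example above, not required by
# the API): for ``fn(x, y) = x ** 2 + y ** 2``, only ``x`` carries a non-zero
# tangent here, so the Jacobian-vector product should simply be
# ``2 * primal * tangent``.
assert torch.allclose(jvp, 2 * primal * tangent)
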
######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass. At the
# time of writing, it is not possible to create dual tensor
# ``nn.Parameter`` objects. As a workaround, one must register the dual
# tensor as a non-parameter attribute of the module.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

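# As a quick sanity check (an addition, not part of the original example):
# for ``nn.Linear``, the JVP with respect to the parameters is
# ``input @ tangent_weight.T + tangent_bias``.
expected_jvp = input @ tangents['weight'].T + tangents['bias']
assert torch.allclose(jvp, expected_jvp)
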
######################################################################
# Using the Modules stateless API (experimental)
# --------------------------------------------------------------------
# Another way to use ``nn.Module`` with forward AD is to utilize
# the stateless API. NB: At the time of writing, the stateless API is still
# experimental and may be subject to change.

from torch.nn.utils._stateless import functional_call

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)

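# The pattern above can be wrapped into a small helper. The function below is
# a hypothetical convenience sketch (an addition, not part of the original
# tutorial) that computes the JVP of a module's output with respect to its
# parameters for an arbitrary direction.
def module_jvp(module, primals, direction, inp):
    with fwAD.dual_level():
        duals = {name: fwAD.make_dual(p, direction[name])
                 for name, p in primals.items()}
        out = functional_call(module, duals, inp)
        # The unpacked tangent remains valid after the context exits
        return fwAD.unpack_dual(out).tangent

# Recomputing the same JVP as above gives the same result
assert torch.allclose(module_jvp(model, params, tangents, input), jvp2)
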
######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# supporting forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ctx can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        # If the tensor stored in ctx will not also be used in the backward
        # pass, one can manually free it using ``del``
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

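# An illustrative check (an addition to the example above): since
# d/dx exp(x) = exp(x), the JVP computed through ``Fn`` should equal
# ``tangent * exp(primal)``.
assert torch.allclose(jvp, tangent * torch.exp(primal))
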
# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell gradcheck
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)

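# The class below is a hedged sketch (an addition, not part of the original
# tutorial) of a custom Function that implements both the backward and the
# jvp formulas, as mentioned above. It stores the result twice: via
# ``save_for_backward`` for reverse mode and as a ``ctx`` attribute for
# forward mode, following the same pattern as ``Fn``.
class ExpBoth(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Saved for the reverse-mode formula
        ctx.save_for_backward(result)
        # Stored for the forward-mode formula
        ctx.result = result
        return result

    @staticmethod
    def backward(ctx, gO):
        result, = ctx.saved_tensors
        return gO * result

    @staticmethod
    def jvp(ctx, gI):
        return gI * ctx.result

# With both formulas implemented, gradcheck can exercise both modes
torch.autograd.gradcheck(ExpBoth.apply, (primal,), check_forward_ad=True)
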
######################################################################
# [0] https://en.wikipedia.org/wiki/Dual_number
