
Commit 674edc3: Update
1 parent 52e045e

1 file changed (+90, -50)

intermediate_source/forward_ad_tutorial.py

Lines changed: 90 additions & 50 deletions
@@ -3,25 +3,27 @@
 Forward-mode Auto Differentiation
 =================================
 
-This tutorial demonstrates how to compute directional derivatives
-(or, equivalently, Jacobian-vector products) with forward-mode AD.
+This tutorial demonstrates how to use forward-mode AD to compute
+directional derivatives (or equivalently, Jacobian-vector products).
 
 """
 
 ######################################################################
 # Basic Usage
 # --------------------------------------------------------------------
 # Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
-# alongside the forward pass. We can compute a directional derivative
-# with forward-mode AD by performing the forward pass as we usually do,
-# except, prior to calling the function, we first associate with our
-# input with another tensor representing the direction of the directional
-# derivative, or equivalently, the ``v`` in a Jacobian-vector product.
-# We also call this "direction" tensor a tangent tensor.
+# alongside the forward pass. We can use forward-mode AD to compute a
+# directional derivative by performing the forward pass as before,
+# except we first associate our input with another tensor representing
+# the direction of the directional derivative (or equivalently, the ``v``
+# in a Jacobian-vector product). When an input, which we call the "primal",
+# is associated with a "direction" tensor, which we call the "tangent", the
+# resulting tensor object is called a "dual tensor" for its connection
+# to dual numbers [0].
 #
-# As the forward pass is performed (if any input tensors have associated
-# tangents) extra computation is performed to propogate this "sensitivity"
-# of the function.
+# As the forward pass is performed, if any input tensors are dual tensors,
+# extra computation is performed to propagate this "sensitivity" of the
+# function.
 #
 # [0] https://en.wikipedia.org/wiki/Dual_number
 
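As a brief aside on the dual-number connection referenced above (not part of the diff): because ε² = 0, f(a + b·ε) = f(a) + f'(a)·b·ε, e.g. (a + b·ε)² = a² + 2ab·ε. The primal plays the role of a, the tangent plays the role of b, and the coefficient of ε carried through the computation is exactly the directional derivative (the Jacobian-vector product in the multivariate case).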
@@ -35,31 +37,32 @@ def fn(x, y):
     return x ** 2 + y ** 2
 
 # All forward AD computation must be performed in the context of
-# the a ``dual_level`` context. All dual tensors created in a
-# context will have their tangents destoryed upon exit. This is to ensure that
+# a ``dual_level`` context. All dual tensors created in such a context
+# will have their tangents destroyed upon exit. This is to ensure that
 # if the output or intermediate results of this computation are reused
 # in a future forward AD computation, their tangents (which are associated
-# with this computation) won't be confused with tangents from later computation.
+# with this computation) won't be confused with tangents from the later
+# computation.
 with fwAD.dual_level():
     # To create a dual tensor we associate a tensor, which we call the
-    # primal with another tensor of the same size, which call the tangent.
+    # primal, with another tensor of the same size, which we call the tangent.
     # If the layout of the tangent is different from that of the primal,
     # the values of the tangent are copied into a new tensor with the same
     # metadata as the primal. Otherwise, the tangent itself is used as-is.
     #
-    # It is important to take note that the dual tensor created by
-    # ``make_dual``` is a view of the primal.
+    # It is also important to note that the dual tensor created by
+    # ``make_dual`` is a view of the primal.
     dual_input = fwAD.make_dual(primal, tangent)
     assert dual_input._base is primal
     assert fwAD.unpack_dual(dual_input).tangent is tangent
 
-    # Any tensor involved in the computation that do not have an associated tangent,
-    # are automatically considered to have a zero-filled tangent.
+    # Tensors that do not have an associated tangent are automatically
+    # considered to have a zero-filled tangent of the same shape.
     plain_tensor = torch.randn(10, 10)
     dual_output = fn(dual_input, plain_tensor)
 
-    # Unpacking the dual returns a namedtuple, with primal and tangent as its
-    # attributes
+    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
+    # as attributes
     jvp = fwAD.unpack_dual(dual_output).tangent
 
 assert fwAD.unpack_dual(dual_output).tangent is None
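As a standalone sanity check of the Basic Usage snippet above (not part of the diff; it re-creates ``fn`` and the inputs so it runs on its own), the tangent returned by ``unpack_dual`` can be compared against the closed-form directional derivative of ``fn`` and against ``torch.autograd.functional.jvp``, which computes the same Jacobian-vector product:

import torch
import torch.autograd.forward_ad as fwAD

def fn(x, y):
    return x ** 2 + y ** 2

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
plain_tensor = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input, plain_tensor)
    jvp = fwAD.unpack_dual(dual_output).tangent

# fn is elementwise in x, so the directional derivative w.r.t. x in direction
# `tangent` is simply 2 * primal * tangent.
assert torch.allclose(jvp, 2 * primal * tangent)

# Cross-check against the functional API, which computes the same
# Jacobian-vector product (via the double-backward trick rather than forward AD).
_, jvp_ref = torch.autograd.functional.jvp(lambda x: fn(x, plain_tensor), primal, tangent)
assert torch.allclose(jvp, jvp_ref)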
@@ -69,59 +72,96 @@ def fn(x, y):
 # Usage with Modules
 # --------------------------------------------------------------------
 # To use ``nn.Module``s with forward AD, replace the parameters of your
-# model with dual tensors before performing the forward pass.
+# model with dual tensors before performing the forward pass. At the
+# time of writing, it is not possible to create dual tensor
+# ``nn.Parameter``s. As a workaround, one must register the dual tensor
+# as a non-parameter attribute of the module.
 
 import torch.nn as nn
 
-model = nn.Linear(10, 10)
-input = torch.randn(64, 10)
+model = nn.Linear(5, 5)
+input = torch.randn(16, 5)
 
-with fwAD.dual_level():
-    for name, p in model.named_parameters():
-        # detach to avoid the extra overhead of creating the backward graph
-        # print(p)
-        # Oh no! This doesn't quite work yet because make_dua
-        # I don't think subclassing works with forward AD because...
-        #
-        # dual_param = fwAD.make_dual(p.detach(), torch.randn_like(p))
-        dual_param = fwAD.make_dual(p, torch.rand_like(p))
-        setattr(model, name, dual_param)
-    print(fwAD.unpack_dual(getattr(model, "weight")))
-    out = model(input)
+params = {name: p for name, p in model.named_parameters()}
+tangents = {name: torch.rand_like(p) for name, p in params.items()}
 
-# print(fwAD.unpack_dual(next(model.parameters())).tangent)
+with fwAD.dual_level():
+    for name, p in params.items():
+        delattr(model, name)
+        setattr(model, name, fwAD.make_dual(p, tangents[name]))
 
+    out = model(input)
     jvp = fwAD.unpack_dual(out).tangent
 
-print("2", jvp)
-
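For a linear layer the result above is easy to verify by hand: with ``out = input @ W.T + b`` and parameter tangents ``(v_W, v_b)``, the output tangent is ``input @ v_W.T + v_b``. A self-contained sketch of that check (not part of the diff; it rebuilds the same model/tangent setup and relies on ``nn.Linear``'s parameter names ``weight`` and ``bias``):

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))
    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

# The JVP of x @ W.T + b w.r.t. (W, b) in directions (v_W, v_b) is x @ v_W.T + v_b.
expected = input @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp, expected, atol=1e-6)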
 ######################################################################
-# Using Modules stateless API
+# Using Modules stateless API (experimental)
 # --------------------------------------------------------------------
 # Another way to use ``nn.Module``s with forward AD is to utilize
-# the stateless API:
+# the stateless API. NB: At the time of writing, the stateless API is still
+# experimental and may be subject to change.
 
 from torch.nn.utils._stateless import functional_call
 
-params = {}
+# We need a fresh module because the functional call requires the
+# model to have parameters registered.
+model = nn.Linear(5, 5)
+
+dual_params = {}
 with fwAD.dual_level():
-    for name, p in model.named_parameters():
-        params[name] = fwAD.make_dual(p, torch.randn_like(p))
-    out = functional_call(model, params, input)
-    jvp = fwAD.unpack_dual(out).tangent
+    for name, p in params.items():
+        # Using the same ``tangents`` from the above section
+        dual_params[name] = fwAD.make_dual(p, tangents[name])
+    out = functional_call(model, dual_params, input)
+    jvp2 = fwAD.unpack_dual(out).tangent
 
+# Check our results
+assert torch.allclose(jvp, jvp2)
 
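Note that ``torch.nn.utils._stateless`` is a private module. In later PyTorch releases the same functionality is exposed publicly (as ``torch.nn.utils.stateless.functional_call`` and, from 2.0, ``torch.func.functional_call``). A sketch of the equivalent call, under the assumption of a PyTorch version that ships ``torch.func``:

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD
from torch.func import functional_call  # assumes PyTorch >= 2.0

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    # Build dual parameters without mutating the module's own parameters.
    dual_params = {name: fwAD.make_dual(p, tangents[name]) for name, p in params.items()}
    out = functional_call(model, dual_params, input)
    jvp = fwAD.unpack_dual(out).tangent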
 ######################################################################
 # Custom autograd Function
 # --------------------------------------------------------------------
-# Hello world
+# Custom Functions also support forward-mode AD. To create a custom Function
+# supporting forward-mode AD, register the ``jvp()`` static method. It is
+# possible, but not mandatory, for custom Functions to support both forward
+# and backward AD. See the
+# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
+# for more information.
 
 class Fn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, foo):
-        return foo * 2
+        result = torch.exp(foo)
+        # Tensors stored in ctx can be used in the subsequent forward grad
+        # computation.
+        ctx.result = result
+        return result
 
     @staticmethod
     def jvp(ctx, gI):
-        torch.randn_like(gI)
-        return gI * 2
+        gO = gI * ctx.result
+        # If the tensor stored in ctx will not also be used in the backward pass,
+        # one can manually free it using ``del``
+        del ctx.result
+        return gO
+
+fn = Fn.apply
+
+primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
+tangent = torch.randn(10, 10, dtype=torch.double)
+
+with fwAD.dual_level():
+    dual_input = fwAD.make_dual(primal, tangent)
+    dual_output = fn(dual_input)
+    jvp = fwAD.unpack_dual(dual_output).tangent
+
+# It is important to use ``autograd.gradcheck`` to verify that your
+# custom autograd Function computes the gradients correctly. By default,
+# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify
+# ``check_forward_ad=True`` to also check forward grads. If you did not
+# implement the backward formula for your function, you can also tell gradcheck
+# to skip the tests that require backward-mode AD by specifying
+# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
+# ``check_batched_grad=False``.
+torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
+                         check_backward_ad=False, check_undefined_grad=False,
+                         check_batched_grad=False)
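Beyond ``gradcheck``, the forward-mode result of this particular Function can also be checked analytically, since d/dx exp(x) = exp(x). A self-contained sketch (not part of the diff; the Function is redefined locally under a hypothetical name ``ExpFn``):

import torch
import torch.autograd.forward_ad as fwAD

class ExpFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        del ctx.result
        return gO

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10, dtype=torch.double)

with fwAD.dual_level():
    dual_output = ExpFn.apply(fwAD.make_dual(primal, tangent))
    jvp = fwAD.unpack_dual(dual_output).tangent

# The JVP of exp at `primal` in direction `tangent` is exp(primal) * tangent.
assert torch.allclose(jvp, torch.exp(primal) * tangent)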
