"""
Forward-mode Auto Differentiation
=================================

This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or equivalently, Jacobian-vector products).

"""

######################################################################
# Basic Usage
# --------------------------------------------------------------------
# Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
# alongside the forward pass. We can use forward-mode AD to compute a
# directional derivative by performing the forward pass as before,
# except we first associate our input with another tensor representing
# the direction of the directional derivative (or equivalently, the ``v``
# in a Jacobian-vector product). When an input, which we call the "primal",
# is associated with a "direction" tensor, which we call the "tangent", the
# resulting new tensor object is called a "dual tensor" for its connection
# to dual numbers [0].
#
# As the forward pass is performed, if any input tensors are dual tensors,
# extra computation is performed to propagate this "sensitivity" of the
# function.
#
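# Intuitively, evaluating ``f`` at the dual number ``x + v * eps`` (where
# ``eps * eps == 0``) yields ``f(x) + (J @ v) * eps``: the coefficient of
# ``eps`` is exactly the Jacobian-vector product that forward-mode AD computes.
#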
# [0] https://en.wikipedia.org/wiki/Dual_number

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2
# All forward AD computation must be performed within a ``dual_level``
# context. All dual tensors created in such a context will have their
# tangents destroyed upon exit. This is to ensure that if the output or
# intermediate results of this computation are reused in a future forward AD
# computation, their tangents (which are associated with this computation)
# won't be confused with tangents from the later computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert dual_input._base is primal
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as attributes.
    jvp = fwAD.unpack_dual(dual_output).tangent
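
    # As a quick sanity check (an assertion we add here for illustration):
    # for ``fn(x, y) = x ** 2 + y ** 2`` the JVP with respect to ``x`` in
    # direction ``tangent`` is ``2 * x * tangent`` (``y`` carries a zero
    # tangent), so the computed ``jvp`` should match it.
    assert torch.allclose(jvp, 2 * primal * tangent)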
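# Outside of the ``dual_level`` context, the tangents of the dual tensors
# created above have been destroyed, so unpacking the output now returns
# ``None`` for the tangent.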
assert fwAD.unpack_dual(dual_output).tangent is None

######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module``s with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass. At the
# time of writing, it is not possible to create dual tensor
# ``nn.Parameter``s. As a workaround, one must register the dual tensor
# as a non-parameter attribute of the module.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

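# In the loop below, ``delattr`` removes each registered parameter from the
# module and ``setattr`` then installs the corresponding dual tensor as a
# plain attribute, which the forward pass picks up instead.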
with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent
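
# As an extra check (added for illustration): ``nn.Linear`` computes
# ``out = input @ weight.T + bias``, so perturbing the parameters along
# ``tangents`` should give the JVP computed below.
expected_jvp = input @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp, expected_jvp)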
######################################################################
# Using Modules stateless API (experimental)
# --------------------------------------------------------------------
# Another way to use ``nn.Module``s with forward AD is to utilize
# the stateless API. NB: At the time of writing the stateless API is still
# experimental and may be subject to change.

from torch.nn.utils._stateless import functional_call

# We need a fresh module here because the functional call requires the
# model to have its parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# supporting forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ``ctx`` can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
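        # The derivative of ``exp`` is ``exp`` itself, so the JVP is the
        # incoming tangent ``gI`` scaled elementwise by the saved result.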
        gO = gI * ctx.result
        # If the tensor stored in ``ctx`` will not also be used in the backward
        # pass, one can manually free it using ``del``.
        del ctx.result
        return gO

fn = Fn.apply

# Note: ``gradcheck`` (used below) needs at least one input that requires
# grad and works best with double-precision inputs.
primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell gradcheck
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)