# -*- coding: utf-8 -*-
"""
- Forward-mode Auto Differentiation
- =================================
+ Forward-mode Automatic Differentiation
+ ======================================
This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or equivalently, Jacobian-vector products).
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
alongside the forward pass. We can use forward-mode AD to compute a
directional derivative by performing the forward pass as before,
- except we first associate with our input with another tensor representing
+ except we first associate our input with another tensor representing
the direction of the directional derivative (or equivalently, the ``v``
- in a Jacobian-vector product). When a input, which we call "primal", is
+ in a Jacobian-vector product). When an input, which we call "primal", is
associated with a "direction" tensor, which we call "tangent", the
resultant new tensor object is called a "dual tensor" for its connection
to dual numbers[0].
extra computation is performed to propagate this "sensitivity" of the
function.
- [0] https://en.wikipedia.org/wiki/Dual_number
-
"""
import torch
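######################################################################
# As a quick illustration of the idea above: for ``f(x) = x ** 2`` the
# directional derivative in the direction ``v`` is ``2 * x * v``. The
# snippet below is a minimal sketch (using only the public ``dual_level``,
# ``make_dual``, and ``unpack_dual`` API shown in this tutorial) that
# checks this with forward-mode AD.
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
v = torch.randn(3)
with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)
    dual_y = dual_x ** 2
    # The tangent of the output is exactly the Jacobian-vector product
    assert torch.allclose(fwAD.unpack_dual(dual_y).tangent, 2 * x * v)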
@@ -52,9 +50,13 @@ def fn(x, y):
# It is also important to note that the dual tensor created by
# ``make_dual`` is a view of the primal.
dual_input = fwAD.make_dual(primal, tangent)
- assert dual_input._base is primal
assert fwAD.unpack_dual(dual_input).tangent is tangent
+ # To demonstrate the case where the tangent is copied,
+ # we pass in a tangent with a layout different from that of the primal
+ dual_input_alt = fwAD.make_dual(primal, tangent.T)
+ assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent
+
# Tensors that do not have an associated tangent are automatically
# considered to have a zero-filled tangent of the same shape.
plain_tensor = torch.randn(10, 10)
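######################################################################
# A plain tensor therefore contributes nothing to the resulting tangent.
# Below is a minimal, self-contained sketch checking this claim, assuming
# (as the hunk headers suggest) that ``fn(x, y)`` computes ``x ** 2 + y ** 2``;
# its body lies outside this excerpt. Since ``plain_tensor`` carries a
# zero-filled tangent, the JVP reduces to ``2 * primal * tangent``.
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
plain_tensor = torch.randn(10, 10)
with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = dual_input ** 2 + plain_tensor ** 2
    # Only ``dual_input`` carries a tangent, so only it contributes to the JVP
    assert torch.allclose(fwAD.unpack_dual(dual_output).tangent,
                          2 * primal * tangent)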
@@ -65,7 +67,6 @@ def fn(x, y):
jvp = fwAD.unpack_dual(dual_output).tangent
assert fwAD.unpack_dual(dual_output).tangent is None
- output = fwAD.unpack_dual(dual_output)
######################################################################
# Usage with Modules
@@ -145,7 +146,7 @@ def jvp(ctx, gI):
fn = Fn.apply
- primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)  # Fix this?
+ primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)
with fwAD.dual_level():
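######################################################################
# ``Fn`` above is a custom ``torch.autograd.Function`` defined earlier in
# the tutorial, outside this excerpt. As a minimal sketch of what such a
# Function can look like, the hypothetical ``MyFn`` below implements
# ``exp`` and exposes forward-mode support through a ``jvp`` staticmethod
# (matching the ``def jvp(ctx, gI)`` signature visible in the hunk headers).
import torch
import torch.autograd.forward_ad as fwAD

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        # Stash whatever ``jvp`` will need directly on ``ctx``
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        # d/dx exp(x) = exp(x), so the output tangent is gI * exp(x)
        gO = gI * ctx.result
        del ctx.result
        return gO

x = torch.randn(3, dtype=torch.double)
v = torch.randn(3, dtype=torch.double)
with fwAD.dual_level():
    dual_out = MyFn.apply(fwAD.make_dual(x, v))
    assert torch.allclose(fwAD.unpack_dual(dual_out).tangent, v * torch.exp(x))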
@@ -163,4 +164,7 @@ def jvp(ctx, gI):
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
check_backward_ad=False, check_undefined_grad=False,
- check_batched_grad=False)
+ check_batched_grad=False)
+
+ ######################################################################
+ # [0] https://en.wikipedia.org/wiki/Dual_number