
Commit 3102426

Address comments
1 parent 9687d01 commit 3102426

File tree

1 file changed: +14 -10 lines changed


intermediate_source/forward_ad_tutorial.py

Lines changed: 14 additions & 10 deletions
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 """
-Forward-mode Auto Differentiation
-=================================
+Forward-mode Automatic Differentiation
+======================================

 This tutorial demonstrates how to use forward-mode AD to compute
 directional derivatives (or equivalently, Jacobian-vector products).
@@ -11,9 +11,9 @@
 Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
 alongside the forward pass. We can use forward-mode AD to compute a
 directional derivative by performing the forward pass as before,
-except we first associate with our input with another tensor representing
+except we first associate our input with another tensor representing
 the direction of the directional derivative (or equivalently, the ``v``
-in a Jacobian-vector product). When a input, which we call "primal", is
+in a Jacobian-vector product). When an input, which we call "primal", is
 associated with a "direction" tensor, which we call "tangent", the
 resultant new tensor object is called a "dual tensor" for its connection
 to dual numbers[0].
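
A minimal end-to-end sketch of the flow described above, not part of this commit; the squaring function ``fn`` is a stand-in chosen for illustration:

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)  # the ``v`` in the Jacobian-vector product

def fn(x):
    return x ** 2

# Dual tensors only exist inside a ``dual_level`` context.
with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    # The tangent of the output is the directional derivative of ``fn``
    # at ``primal`` in the direction ``tangent``.
    jvp = fwAD.unpack_dual(dual_output).tangent

# For ``fn(x) = x ** 2`` the JVP is ``2 * x * v`` elementwise.
assert torch.allclose(jvp, 2 * primal * tangent)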
@@ -22,8 +22,6 @@
 extra computation is performed to propagate this "sensitivity" of the
 function.

-[0] https://en.wikipedia.org/wiki/Dual_number
-
 """

 import torch
@@ -52,9 +50,13 @@ def fn(x, y):
 # It is also important to note that the dual tensor created by
 # ``make_dual`` is a view of the primal.
 dual_input = fwAD.make_dual(primal, tangent)
-assert dual_input._base is primal
 assert fwAD.unpack_dual(dual_input).tangent is tangent

+# To demonstrate the case where the tangent is copied, we pass in a
+# tangent with a layout different from that of the primal.
+dual_input_alt = fwAD.make_dual(primal, tangent.T)
+assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent
+
 # Tensors that do not have an associated tangent are automatically
 # considered to have a zero-filled tangent of the same shape.
 plain_tensor = torch.randn(10, 10)
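
The zero-filled-tangent behavior mentioned in the context lines above can be checked with a small sketch like the following (illustrative values only):

import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.ones(2), torch.ones(2))
    plain = torch.full((2,), 3.0)  # no tangent attached
    out = dual * plain             # ``plain`` behaves like a constant
    # The plain tensor's implicit tangent is zero, so by the product rule
    # the output tangent is ``plain * tangent`` = plain (tangent is all ones).
    assert torch.allclose(fwAD.unpack_dual(out).tangent, plain)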
@@ -65,7 +67,6 @@ def fn(x, y):
 jvp = fwAD.unpack_dual(dual_output).tangent

 assert fwAD.unpack_dual(dual_output).tangent is None
-output = fwAD.unpack_dual(dual_output)

 ######################################################################
 # Usage with Modules
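
The body of the "Usage with Modules" section falls outside this diff. A sketch of the usual pattern, assuming a model whose parameters are direct attributes such as ``nn.Linear``:

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

model = nn.Linear(5, 5)
inp = torch.randn(16, 5)
params = dict(model.named_parameters())
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        # Swap each parameter for a dual tensor carrying its tangent.
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))
    out = model(inp)
    jvp = fwAD.unpack_dual(out).tangent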
@@ -145,7 +146,7 @@ def jvp(ctx, gI):

 fn = Fn.apply

-primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)  # Fix this?
+primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
 tangent = torch.randn(10, 10)

 with fwAD.dual_level():
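
The definition of ``Fn`` is not included in this hunk. A sketch consistent with the ``jvp(ctx, gI)`` signature visible in the hunk header, using ``torch.exp`` as an assumed example op:

import torch

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        # Stash the result so ``jvp`` can reuse it.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        # The derivative of exp is exp itself, applied to the input tangent.
        gO = gI * ctx.result
        del ctx.result  # free eagerly; not needed by backward here
        return gO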
@@ -163,4 +164,7 @@ def jvp(ctx, gI):
 # ``check_batched_grad=False``.
 torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                          check_backward_ad=False, check_undefined_grad=False,
-                         check_batched_grad=False)
+                         check_batched_grad=False)
+
+######################################################################
+# [0] https://en.wikipedia.org/wiki/Dual_number
