Commit 52e045e

Add forward AD tutorial

1 parent c08519b commit 52e045e

1 file changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""
Forward-mode Automatic Differentiation
======================================

This tutorial demonstrates how to compute directional derivatives
(or, equivalently, Jacobian-vector products) with forward-mode AD.

"""

######################################################################
# Basic Usage
# --------------------------------------------------------------------
# Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
# alongside the forward pass. We can compute a directional derivative
# with forward-mode AD by performing the forward pass as usual,
# except that, prior to calling the function, we first associate our
# input with another tensor representing the direction of the directional
# derivative, or equivalently, the ``v`` in a Jacobian-vector product.
# We also call this "direction" tensor a tangent tensor.
#
# As the forward pass is performed, if any input tensors have associated
# tangents, extra computation is performed to propagate this "sensitivity"
# of the function. Concretely, if ``J`` is the Jacobian of the function at
# the given input, a single forward pass with tangent ``v`` computes the
# Jacobian-vector product ``J @ v`` without ever materializing ``J``.
#
# [0] https://en.wikipedia.org/wiki/Dual_number

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed inside a ``dual_level``
# context. All dual tensors created in a context will have their tangents
# destroyed upon exit. This is to ensure that if the output or intermediate
# results of this computation are reused in a future forward AD computation,
# their tangents (which are associated with this computation) won't be
# confused with tangents from a later computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the
    # tangent (the name "dual tensor" comes from dual numbers [0]).
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert dual_input._base is primal
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # Any tensor involved in the computation that does not have an associated
    # tangent is automatically considered to have a zero-filled tangent.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as its attributes.
    jvp = fwAD.unpack_dual(dual_output).tangent

# Outside of the ``dual_level`` context the tangent has been destroyed, so
# unpacking only recovers the primal.
assert fwAD.unpack_dual(dual_output).tangent is None
output = fwAD.unpack_dual(dual_output).primal
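
# As a quick sanity check: since ``plain_tensor`` had no tangent, only the
# ``x ** 2`` term contributes, so the directional derivative of ``fn`` along
# ``tangent`` reduces to the closed form ``2 * primal * tangent``. The
# forward-mode result should match it.
expected_jvp = 2 * primal * tangent
assert torch.allclose(jvp, expected_jvp)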


######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module``s with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass.

import torch.nn as nn

model = nn.Linear(10, 10)
input = torch.randn(64, 10)

# Keep references to the original parameters and sample one tangent
# (direction) per parameter.
params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.randn_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        # ``nn.Module`` only accepts ``nn.Parameter``s for registered
        # parameter names, so remove the parameter first and assign the
        # dual tensor as a plain attribute in its place.
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent
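
# The Jacobian-vector product we just computed can be checked against the
# closed form for a linear layer: since ``nn.Linear`` computes
# ``input @ weight.T + bias``, perturbing the parameters along ``tangents``
# changes the output by ``input @ tangents["weight"].T + tangents["bias"]``.
expected_jvp = input @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp, expected_jvp, atol=1e-6)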

######################################################################
# Using the stateless Module API
# --------------------------------------------------------------------
# Another way to use ``nn.Module``s with forward AD is to utilize
# the stateless API:

from torch.nn.utils._stateless import functional_call

# The previous section replaced the registered parameters of ``model`` with
# plain dual tensors, so create a fresh module whose parameters the stateless
# call can substitute at call time.
model = nn.Linear(10, 10)

dual_params = {}
with fwAD.dual_level():
    for name, p in model.named_parameters():
        dual_params[name] = fwAD.make_dual(p, torch.randn_like(p))
    out = functional_call(model, dual_params, input)
    jvp = fwAD.unpack_dual(out).tangent

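
# Unlike the approach in the previous section, the stateless call does not
# mutate the module: the dual tensors are only swapped in for the duration of
# the call, so the registered parameters are still intact afterwards.
assert isinstance(model.weight, nn.Parameter)
assert isinstance(model.bias, nn.Parameter)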

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom autograd Functions can support forward-mode AD by implementing
# the ``jvp()`` static method. It receives one tangent per input to
# ``forward()`` and should return the tangent of each output.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo * 2

    @staticmethod
    def jvp(ctx, gI):
        # ``forward()`` is linear, so the output tangent is the input
        # tangent scaled by the same factor.
        return gI * 2
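
# A short usage sketch for ``Fn``: calling it on a dual tensor inside a
# ``dual_level`` context routes the tangent through ``jvp()``. Because the
# function is linear, the resulting JVP is exactly ``2 * tangent``.
with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = Fn.apply(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent
    assert torch.allclose(jvp, 2 * tangent)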
