From 52e045ec9d1bbd95a4bdff311a0823f1a0920525 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 18 Nov 2021 13:53:09 -0500 Subject: [PATCH 1/5] Add forward AD tutorial --- intermediate_source/forward_ad_tutorial.py | 127 +++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 intermediate_source/forward_ad_tutorial.py diff --git a/intermediate_source/forward_ad_tutorial.py b/intermediate_source/forward_ad_tutorial.py new file mode 100644 index 00000000000..40c288427e7 --- /dev/null +++ b/intermediate_source/forward_ad_tutorial.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +Forward-mode Auto Differentiation +================================= + +This tutorial demonstrates how to compute directional derivatives +(or, equivalently, Jacobian-vector products) with forward-mode AD. + +""" + +###################################################################### +# Basic Usage +# -------------------------------------------------------------------- +# Unlike reverse-mode AD, forward-mode AD computes gradients eagerly +# alongside the forward pass. We can compute a directional derivative +# with forward-mode AD by performing the forward pass as we usually do, +# except, prior to calling the function, we first associate with our +# input with another tensor representing the direction of the directional +# derivative, or equivalently, the ``v`` in a Jacobian-vector product. +# We also call this "direction" tensor a tangent tensor. +# +# As the forward pass is performed (if any input tensors have associated +# tangents) extra computation is performed to propogate this "sensitivity" +# of the function. +# +# [0] https://en.wikipedia.org/wiki/Dual_number + +import torch +import torch.autograd.forward_ad as fwAD + +primal = torch.randn(10, 10) +tangent = torch.randn(10, 10) + +def fn(x, y): + return x ** 2 + y ** 2 + +# All forward AD computation must be performed in the context of +# the a ``dual_level`` context. All dual tensors created in a +# context will have their tangents destoryed upon exit. This is to ensure that +# if the output or intermediate results of this computation are reused +# in a future forward AD computation, their tangents (which are associated +# with this computation) won't be confused with tangents from later computation. +with fwAD.dual_level(): + # To create a dual tensor we associate a tensor, which we call the + # primal with another tensor of the same size, which call the tangent. + # If the layout of the tangent is different from that of the primal, + # The values of the tangent are copied into a new tensor with the same + # metadata as the primal. Otherwise, the tangent itself is used as-is. + # + # It is important to take note that the dual tensor created by + # ``make_dual``` is a view of the primal. + dual_input = fwAD.make_dual(primal, tangent) + assert dual_input._base is primal + assert fwAD.unpack_dual(dual_input).tangent is tangent + + # Any tensor involved in the computation that do not have an associated tangent, + # are automatically considered to have a zero-filled tangent. 
+ plain_tensor = torch.randn(10, 10) + dual_output = fn(dual_input, plain_tensor) + + # Unpacking the dual returns a namedtuple, with primal and tangent as its + # attributes + jvp = fwAD.unpack_dual(dual_output).tangent + +assert fwAD.unpack_dual(dual_output).tangent is None +output = fwAD.unpack_dual(dual_output) + +###################################################################### +# Usage with Modules +# -------------------------------------------------------------------- +# To use ``nn.Module``s with forward AD, replace the parameters of your +# model with dual tensors before performing the forward pass. + +import torch.nn as nn + +model = nn.Linear(10, 10) +input = torch.randn(64, 10) + +with fwAD.dual_level(): + for name, p in model.named_parameters(): + # detach to avoid the extra overhead of creating the backward graph + # print(p) + # Oh no! This doesn't quite work yet because make_dua + # I don't think subclassing works with forward AD because... + # + # dual_param = fwAD.make_dual(p.detach(), torch.randn_like(p)) + dual_param = fwAD.make_dual(p, torch.rand_like(p)) + setattr(model, name, dual_param) + print(fwAD.unpack_dual(getattr(model, "weight"))) + out = model(input) + + # print(fwAD.unpack_dual(next(model.parameters())).tangent) + + jvp = fwAD.unpack_dual(out).tangent + + print("2", jvp) + +###################################################################### +# Using Modules stateless API +# -------------------------------------------------------------------- +# Another way to use ``nn.Module``s with forward AD is to utilize +# the stateless API: + +from torch.nn.utils._stateless import functional_call + +params = {} +with fwAD.dual_level(): + for name, p in model.named_parameters(): + params[name] = fwAD.make_dual(p, torch.randn_like(p)) + out = functional_call(model, params, input) + jvp = fwAD.unpack_dual(out).tangent + + +###################################################################### +# Custom autograd Function +# -------------------------------------------------------------------- +# Hello world + +class Fn(torch.autograd.Function): + @staticmethod + def forward(ctx, foo): + return foo * 2 + + @staticmethod + def jvp(ctx, gI): + torch.randn_like(gI) + return gI * 2 From 674edc3bf0adf3484f1bb8b5b2a00c98a656357f Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 29 Nov 2021 11:30:44 -0500 Subject: [PATCH 2/5] Update --- intermediate_source/forward_ad_tutorial.py | 140 +++++++++++++-------- 1 file changed, 90 insertions(+), 50 deletions(-) diff --git a/intermediate_source/forward_ad_tutorial.py b/intermediate_source/forward_ad_tutorial.py index 40c288427e7..fb579f88a3c 100644 --- a/intermediate_source/forward_ad_tutorial.py +++ b/intermediate_source/forward_ad_tutorial.py @@ -3,8 +3,8 @@ Forward-mode Auto Differentiation ================================= -This tutorial demonstrates how to compute directional derivatives -(or, equivalently, Jacobian-vector products) with forward-mode AD. +This tutorial demonstrates how to use forward-mode AD to compute +directional derivatives (or equivalently, Jacobian-vector products). """ @@ -12,16 +12,18 @@ # Basic Usage # -------------------------------------------------------------------- # Unlike reverse-mode AD, forward-mode AD computes gradients eagerly -# alongside the forward pass. 
We can compute a directional derivative -# with forward-mode AD by performing the forward pass as we usually do, -# except, prior to calling the function, we first associate with our -# input with another tensor representing the direction of the directional -# derivative, or equivalently, the ``v`` in a Jacobian-vector product. -# We also call this "direction" tensor a tangent tensor. +# alongside the forward pass. We can use forward-mode AD to compute a +# directional derivative by performing the forward pass as before, +# except we first associate with our input with another tensor representing +# the direction of the directional derivative (or equivalently, the ``v`` +# in a Jacobian-vector product). When a input, which we call "primal", is +# associated with a "direction" tensor, which we call "tangent", the +# resultant new tensor object is called a "dual tensor" for its connection +# to dual numbers[0]. # -# As the forward pass is performed (if any input tensors have associated -# tangents) extra computation is performed to propogate this "sensitivity" -# of the function. +# As the forward pass is performed, if any input tensors are dual tensors, +# extra computation is performed to propogate this "sensitivity" of the +# function. # # [0] https://en.wikipedia.org/wiki/Dual_number @@ -35,31 +37,32 @@ def fn(x, y): return x ** 2 + y ** 2 # All forward AD computation must be performed in the context of -# the a ``dual_level`` context. All dual tensors created in a -# context will have their tangents destoryed upon exit. This is to ensure that +# a ``dual_level`` context. All dual tensors created in such a context +# will have their tangents destroyed upon exit. This is to ensure that # if the output or intermediate results of this computation are reused # in a future forward AD computation, their tangents (which are associated -# with this computation) won't be confused with tangents from later computation. +# with this computation) won't be confused with tangents from the later +# computation. with fwAD.dual_level(): # To create a dual tensor we associate a tensor, which we call the - # primal with another tensor of the same size, which call the tangent. + # primal with another tensor of the same size, which we call the tangent. # If the layout of the tangent is different from that of the primal, # The values of the tangent are copied into a new tensor with the same # metadata as the primal. Otherwise, the tangent itself is used as-is. # - # It is important to take note that the dual tensor created by - # ``make_dual``` is a view of the primal. + # It is also important to note that the dual tensor created by + # ``make_dual`` is a view of the primal. dual_input = fwAD.make_dual(primal, tangent) assert dual_input._base is primal assert fwAD.unpack_dual(dual_input).tangent is tangent - # Any tensor involved in the computation that do not have an associated tangent, - # are automatically considered to have a zero-filled tangent. + # Tensors that do not not have an associated tangent are automatically + # considered to have a zero-filled tangent of the same shape. 
plain_tensor = torch.randn(10, 10) dual_output = fn(dual_input, plain_tensor) - # Unpacking the dual returns a namedtuple, with primal and tangent as its - # attributes + # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent`` + # as attributes jvp = fwAD.unpack_dual(dual_output).tangent assert fwAD.unpack_dual(dual_output).tangent is None @@ -69,59 +72,96 @@ def fn(x, y): # Usage with Modules # -------------------------------------------------------------------- # To use ``nn.Module``s with forward AD, replace the parameters of your -# model with dual tensors before performing the forward pass. +# model with dual tensors before performing the forward pass. At the +# time of writing, it is not possible to create dual tensor +# `nn.Parameter`s. As a workaround, one must register the dual tensor +# as a non-parameter attribute of the module. import torch.nn as nn -model = nn.Linear(10, 10) -input = torch.randn(64, 10) +model = nn.Linear(5, 5) +input = torch.randn(16, 5) -with fwAD.dual_level(): - for name, p in model.named_parameters(): - # detach to avoid the extra overhead of creating the backward graph - # print(p) - # Oh no! This doesn't quite work yet because make_dua - # I don't think subclassing works with forward AD because... - # - # dual_param = fwAD.make_dual(p.detach(), torch.randn_like(p)) - dual_param = fwAD.make_dual(p, torch.rand_like(p)) - setattr(model, name, dual_param) - print(fwAD.unpack_dual(getattr(model, "weight"))) - out = model(input) +params = {name: p for name, p in model.named_parameters()} +tangents = {name: torch.rand_like(p) for name, p in params.items()} - # print(fwAD.unpack_dual(next(model.parameters())).tangent) +with fwAD.dual_level(): + for name, p in params.items(): + delattr(model, name) + setattr(model, name, fwAD.make_dual(p, tangents[name])) + out = model(input) jvp = fwAD.unpack_dual(out).tangent - print("2", jvp) - ###################################################################### -# Using Modules stateless API +# Using Modules stateless API (experimental) # -------------------------------------------------------------------- # Another way to use ``nn.Module``s with forward AD is to utilize -# the stateless API: +# the stateless API. NB: At the time of writing the stateless API is still +# experimental and may be subject to change. from torch.nn.utils._stateless import functional_call -params = {} +# We need a fresh module because the functional call requires the +# the model to have parameters registered. +model = nn.Linear(5, 5) + +dual_params = {} with fwAD.dual_level(): - for name, p in model.named_parameters(): - params[name] = fwAD.make_dual(p, torch.randn_like(p)) - out = functional_call(model, params, input) - jvp = fwAD.unpack_dual(out).tangent + for name, p in params.items(): + # Using the same ``tangents`` from the above section + dual_params[name] = fwAD.make_dual(p, tangents[name]) + out = functional_call(model, dual_params, input) + jvp2 = fwAD.unpack_dual(out).tangent +# Check our results +assert torch.allclose(jvp, jvp2) ###################################################################### # Custom autograd Function # -------------------------------------------------------------------- -# Hello world +# Custom Functions also support forward-mode AD. To create custom Function +# supporting forward-mode AD, register the ``jvp()`` static method. It is +# possible, but not mandatory for custom Functions to support both forward +# and backward AD. See the +# `documentation `_ +# for more information. 
class Fn(torch.autograd.Function): @staticmethod def forward(ctx, foo): - return foo * 2 + result = torch.exp(foo) + # Tensors stored in ctx can be used in the subsequent forward grad + # computation. + ctx.result = result + return result @staticmethod def jvp(ctx, gI): - torch.randn_like(gI) - return gI * 2 + gO = gI * ctx.result + # If the tensor stored in ctx will not also be used in the backward pass, + # one can manually free it using ``del`` + del ctx.result + return gO + +fn = Fn.apply + +primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True) # Fix this? +tangent = torch.randn(10, 10) + +with fwAD.dual_level(): + dual_input = fwAD.make_dual(primal, tangent) + dual_output = fn(dual_input) + jvp = fwAD.unpack_dual(dual_output).tangent + +# It is important to use ``autograd.gradcheck`` to verify that your +# custom autograd Function computes the gradients correctly. By default, +# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify +# ``check_forward_ad=True`` to also check forward grads. If you did not +# implement the backward formula for your function, you can also tell gradcheck +# to skip the tests that require backward-mode AD by specifying +# ``check_backward_ad=False``, ``check_undefined_grad=False``, and +# ``check_batched_grad=False``. +torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True, + check_backward_ad=False, check_undefined_grad=False, + check_batched_grad=False) \ No newline at end of file From 9687d015b12ec14ee21f6f93bd07310fe68a9ec0 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 29 Nov 2021 11:39:18 -0500 Subject: [PATCH 3/5] Formatting --- intermediate_source/forward_ad_tutorial.py | 37 +++++++++++----------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/intermediate_source/forward_ad_tutorial.py b/intermediate_source/forward_ad_tutorial.py index fb579f88a3c..795a95e56d3 100644 --- a/intermediate_source/forward_ad_tutorial.py +++ b/intermediate_source/forward_ad_tutorial.py @@ -6,26 +6,25 @@ This tutorial demonstrates how to use forward-mode AD to compute directional derivatives (or equivalently, Jacobian-vector products). -""" +Basic Usage +-------------------------------------------------------------------- +Unlike reverse-mode AD, forward-mode AD computes gradients eagerly +alongside the forward pass. We can use forward-mode AD to compute a +directional derivative by performing the forward pass as before, +except we first associate with our input with another tensor representing +the direction of the directional derivative (or equivalently, the ``v`` +in a Jacobian-vector product). When a input, which we call "primal", is +associated with a "direction" tensor, which we call "tangent", the +resultant new tensor object is called a "dual tensor" for its connection +to dual numbers[0]. + +As the forward pass is performed, if any input tensors are dual tensors, +extra computation is performed to propogate this "sensitivity" of the +function. + +[0] https://en.wikipedia.org/wiki/Dual_number -###################################################################### -# Basic Usage -# -------------------------------------------------------------------- -# Unlike reverse-mode AD, forward-mode AD computes gradients eagerly -# alongside the forward pass. 
We can use forward-mode AD to compute a -# directional derivative by performing the forward pass as before, -# except we first associate with our input with another tensor representing -# the direction of the directional derivative (or equivalently, the ``v`` -# in a Jacobian-vector product). When a input, which we call "primal", is -# associated with a "direction" tensor, which we call "tangent", the -# resultant new tensor object is called a "dual tensor" for its connection -# to dual numbers[0]. -# -# As the forward pass is performed, if any input tensors are dual tensors, -# extra computation is performed to propogate this "sensitivity" of the -# function. -# -# [0] https://en.wikipedia.org/wiki/Dual_number +""" import torch import torch.autograd.forward_ad as fwAD From 3102426fde898f63f7c1db3c6099bb161a04ce9c Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 1 Dec 2021 16:47:53 -0500 Subject: [PATCH 4/5] Address comments --- intermediate_source/forward_ad_tutorial.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/intermediate_source/forward_ad_tutorial.py b/intermediate_source/forward_ad_tutorial.py index 795a95e56d3..46416693c04 100644 --- a/intermediate_source/forward_ad_tutorial.py +++ b/intermediate_source/forward_ad_tutorial.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Forward-mode Auto Differentiation -================================= +Forward-mode Automatic Differentiation +====================================== This tutorial demonstrates how to use forward-mode AD to compute directional derivatives (or equivalently, Jacobian-vector products). @@ -11,9 +11,9 @@ Unlike reverse-mode AD, forward-mode AD computes gradients eagerly alongside the forward pass. We can use forward-mode AD to compute a directional derivative by performing the forward pass as before, -except we first associate with our input with another tensor representing +except we first associate our input with another tensor representing the direction of the directional derivative (or equivalently, the ``v`` -in a Jacobian-vector product). When a input, which we call "primal", is +in a Jacobian-vector product). When an input, which we call "primal", is associated with a "direction" tensor, which we call "tangent", the resultant new tensor object is called a "dual tensor" for its connection to dual numbers[0]. @@ -22,8 +22,6 @@ extra computation is performed to propogate this "sensitivity" of the function. -[0] https://en.wikipedia.org/wiki/Dual_number - """ import torch @@ -52,9 +50,13 @@ def fn(x, y): # It is also important to note that the dual tensor created by # ``make_dual`` is a view of the primal. dual_input = fwAD.make_dual(primal, tangent) - assert dual_input._base is primal assert fwAD.unpack_dual(dual_input).tangent is tangent + # To demonstrate the case where the copy of the tangent happens, + # we pass in a tangent with a layout different from that of the primal + dual_input_alt = fwAD.make_dual(primal, tangent.T) + assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent + # Tensors that do not not have an associated tangent are automatically # considered to have a zero-filled tangent of the same shape. 
plain_tensor = torch.randn(10, 10) @@ -65,7 +67,6 @@ def fn(x, y): jvp = fwAD.unpack_dual(dual_output).tangent assert fwAD.unpack_dual(dual_output).tangent is None -output = fwAD.unpack_dual(dual_output) ###################################################################### # Usage with Modules @@ -145,7 +146,7 @@ def jvp(ctx, gI): fn = Fn.apply -primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True) # Fix this? +primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True) tangent = torch.randn(10, 10) with fwAD.dual_level(): @@ -163,4 +164,7 @@ def jvp(ctx, gI): # ``check_batched_grad=False``. torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True, check_backward_ad=False, check_undefined_grad=False, - check_batched_grad=False) \ No newline at end of file + check_batched_grad=False) + +###################################################################### +# [0] https://en.wikipedia.org/wiki/Dual_number From f345d2f22841e55426ef54de6d4bb0a08ee93582 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 2 Dec 2021 11:48:31 -0500 Subject: [PATCH 5/5] Fix quoting --- .../forward_ad_tutorial.py => forward_ad_usage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename intermediate_source/forward_ad_tutorial.py => forward_ad_usage.py (98%) diff --git a/intermediate_source/forward_ad_tutorial.py b/forward_ad_usage.py similarity index 98% rename from intermediate_source/forward_ad_tutorial.py rename to forward_ad_usage.py index 46416693c04..b521ebbef13 100644 --- a/intermediate_source/forward_ad_tutorial.py +++ b/forward_ad_usage.py @@ -71,7 +71,7 @@ def fn(x, y): ###################################################################### # Usage with Modules # -------------------------------------------------------------------- -# To use ``nn.Module``s with forward AD, replace the parameters of your +# To use ``nn.Module`` with forward AD, replace the parameters of your # model with dual tensors before performing the forward pass. At the # time of writing, it is not possible to create dual tensor # `nn.Parameter`s. As a workaround, one must register the dual tensor @@ -96,7 +96,7 @@ def fn(x, y): ###################################################################### # Using Modules stateless API (experimental) # -------------------------------------------------------------------- -# Another way to use ``nn.Module``s with forward AD is to utilize +# Another way to use ``nn.Module`` with forward AD is to utilize # the stateless API. NB: At the time of writing the stateless API is still # experimental and may be subject to change.
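
A quick sanity check for the basic-usage example: for ``fn(x, y) = x ** 2 + y ** 2`` with a tangent attached only to ``x``, the Jacobian-vector product should equal ``2 * primal * tangent``. The sketch below is not part of the patch; it recreates the same ``fn``, ``primal``, and ``tangent`` so it runs on its own.

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    # ``y`` carries no tangent, so it contributes nothing to the JVP.
    dual_output = fn(dual_input, torch.randn(10, 10))
    jvp = fwAD.unpack_dual(dual_output).tangent

# Analytically, the directional derivative of x ** 2 along ``tangent``
# is 2 * primal * tangent.
assert torch.allclose(jvp, 2 * primal * tangent)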
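
The module and stateless-API sections can be checked the same way with a central finite difference: perturb the parameters along ``tangents`` and compare the resulting slope against the forward-mode JVP. The sketch below is not part of the patch and assumes the ``model``, ``input``, ``params``, ``tangents``, and ``jvp2`` variables from the final version of the tutorial above; ``eps``, ``plus_params``, ``minus_params``, and ``fd_jvp`` are names introduced here for illustration.

from torch.nn.utils._stateless import functional_call

# Central finite-difference estimate of the directional derivative of the
# model output with respect to its parameters, along ``tangents``.
eps = 1e-2
plus_params = {name: p + eps * tangents[name] for name, p in params.items()}
minus_params = {name: p - eps * tangents[name] for name, p in params.items()}
with torch.no_grad():
    fd_jvp = (functional_call(model, plus_params, input)
              - functional_call(model, minus_params, input)) / (2 * eps)
# The model is linear in its parameters, so the estimate should match closely.
assert torch.allclose(jvp2, fd_jvp, atol=1e-3)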
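
Because the custom ``Fn`` in the patch implements ``exp``, its ``jvp`` can also be verified directly against the analytic derivative, in addition to ``gradcheck``. The sketch below is not part of the patch; it reuses the ``Fn`` class and the ``fwAD`` import from above, with fresh tensors (``p``, ``t``) introduced for the example and ``requires_grad=True`` mirroring the tutorial's own setup.

p = torch.randn(10, 10, requires_grad=True)
t = torch.randn(10, 10)

with fwAD.dual_level():
    dual_out = Fn.apply(fwAD.make_dual(p, t))
    fwd_grad = fwAD.unpack_dual(dual_out).tangent

# For exp, the directional derivative along ``t`` is exp(p) * t, which is
# exactly what ``Fn.jvp`` computes from the stored forward result.
assert torch.allclose(fwd_grad, torch.exp(p) * t)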