Add forward AD tutorial #1746
Conversation
fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)  # Fix this?
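For context, here is a minimal self-contained sketch of the pattern those two lines come from. The body of Fn (a simple doubling op, chosen so jvp needs no saved state) and the exact gradcheck flags are assumptions for illustration, not necessarily what the tutorial uses.

import torch
import torch.autograd.forward_ad as fwAD

class Fn(torch.autograd.Function):
    # Hypothetical stand-in for the tutorial's Fn: a linear op so that
    # jvp does not need anything saved from forward.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def jvp(ctx, gI):
        # Jacobian-vector product of x -> 2*x is simply 2 * tangent.
        return gI * 2

fn = Fn.apply
primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10, dtype=torch.double)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = fn(dual)
    jvp_result = fwAD.unpack_dual(out).tangent

# The gradcheck call that the requires_grad discussion below is about:
torch.autograd.gradcheck(fn, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_batched_grad=False,
                         check_undefined_grad=False)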
Gradcheck should not require inputs to require grad if we are not checking backward AD. See: pytorch/pytorch#69004
Not sure this is true.
We do use requires_grad to know which inputs/outputs are differentiable.
Maybe we should not, but today we do require this, and just removing the check doesn't sound like the right solution.
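As a hedged illustration of that current behaviour (not from the tutorial): gradcheck only perturbs and checks inputs that require grad, so an input created without requires_grad is treated as non-differentiable.

import torch

def mul(a, b):
    return a * b

a = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double)  # no requires_grad: treated as non-differentiable

# Only `a` is perturbed and checked; if no input required grad, gradcheck would error out.
torch.autograd.gradcheck(mul, (a, b), check_forward_ad=True)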
You're correct about the inputs, but it looks like gradcheck currently doesn't actually need the outputs to have requires_grad:
import torch
import torch.autograd.forward_ad as fwAD

def detach_backward(x):
    _, tangent = fwAD.unpack_dual(x)
    if tangent is None:  # outside a dual level there is no tangent to re-attach
        return x.detach()
    return fwAD.make_dual(x.detach(), tangent)

class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def jvp(ctx, gI):
        return gI * 3  # Wrong on purpose, let's see if gradcheck catches this

t = torch.tensor(1., requires_grad=True, dtype=torch.double)

def fn(x):
    out1 = detach_backward(Func.apply(x))
    out2 = x.clone()  # only needed to bypass the "no differentiable outputs" check
    return out1, out2

with fwAD.dual_level():
    dual = fwAD.make_dual(t, torch.rand_like(t))
    out = fn(dual)
    assert not out[0].requires_grad

torch.autograd.gradcheck(fn, (t,), check_forward_ad=True,
                         check_backward_ad=False, check_batched_grad=False,
                         check_undefined_grad=False)
I don't know whether the fact that non-requires-grad outputs are checked is useful, though. At least for custom functions, it looks like the output still has requires_grad=True even if only jvp is implemented, i.e., does a function that behaves like detach_backward actually exist in nature?
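A quick self-contained check of that observation (hedged sketch; JvpOnly is a made-up name): the output of a custom Function that only implements jvp still requires grad when its input does.

import torch

class JvpOnly(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def jvp(ctx, gI):
        return gI * 2

t = torch.tensor(1., requires_grad=True, dtype=torch.double)
out = JvpOnly.apply(t)
print(out.requires_grad)  # True, even though no backward is defined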
does a function that behaves like detach_backward actually exist in nature?
Ignoring things like detach, I think the answer is no. If a function is differentiable, it will be for both forward and backward.
Note that this needs an API that is only in master, so the build will fail.
sgtm