We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3bd3892 · commit 6f74b76 — Copy full SHA for 6f74b76
torch/_subclasses/meta_utils.py
@@ -632,8 +632,9 @@ def _to_fake_tensor(t):
632
r = _add_batch_dim(ft, bdim, lvl)
633
elif is_gradtrackingtensor(t):
634
disable_functorch = torch._C._DisableFuncTorch
635
+ ft = get_unwrapped(t)
636
with disable_functorch():
- ft = _to_fake_tensor(get_unwrapped(t))
637
+ ft = _to_fake_tensor(ft)
638
lvl = torch._C._functorch.maybe_get_level(t)
639
r = torch._C._functorch._wrap_for_grad(ft, lvl)
640
0 commit comments