
Commit 7c643ad

[maskedtensor] Distinguish between 0 and NaN gradient

ghstack-source-id: d587709
Pull Request resolved: #2044

1 parent 4b10b89 commit 7c643ad

File tree

2 files changed: +110 −0 lines changed
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
Distinguishing between 0 and NaN gradient
-----------------------------------------

One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
undefined (NaN) and gradients that are actually 0. Below are several issues where
:class:`MaskedTensor` can resolve or work around the NaN gradient problem.

`Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__ -- torch.where
--------------------------------------------------------------------------------

Current result:

>>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
>>> y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
>>> y.sum().backward()
>>> x.grad
tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, nan, nan])

:class:`MaskedTensor` result:

>>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
>>> mask = x < 0
>>> mx = masked_tensor(x, mask, requires_grad=True)
>>> my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
>>> y = torch.where(mask, torch.exp(mx), my)
>>> y.sum().backward()
>>> mx.grad
MaskedTensor(
  [ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --]
)

The gradient here is only provided to the selected subset. Effectively, this changes the gradient of
:func:`torch.where` to mask out elements instead of setting them to zero.

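The nan entries in the plain-tensor result above come from overflow in the unselected branch. A minimal sketch of that mechanism with ordinary tensors (an illustration, not part of this commit):

```python
import torch

# For x = 100, exp(x) overflows to inf in the *unselected* branch of where.
x = torch.tensor([100.], requires_grad=True)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()

# The backward pass scales exp(100) = inf by the branch weight 0,
# and 0 * inf = nan -- hence the nan gradients at x = 90 and 100 above,
# while exp(80) is still finite and yields an exact 0.
print(x.grad)  # tensor([nan])
```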
`Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ -- another torch.where
----------------------------------------------------------------------------------------

Current result:

>>> a = torch.randn((), requires_grad=True)
>>> b = torch.tensor(False)
>>> c = torch.ones(())
>>> torch.where(b, a/0, c)
tensor(1., grad_fn=<WhereBackward0>)
>>> torch.autograd.grad(torch.where(b, a/0, c), a)
(tensor(nan),)

:class:`MaskedTensor` result:

>>> a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
>>> b = torch.tensor(False)
>>> c = torch.ones(())
>>> torch.where(b, a/0, c)
MaskedTensor( 1.0000, True)
>>> torch.autograd.grad(torch.where(b, a/0, c), a)
(MaskedTensor(--, False),)

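The nan in the plain-tensor version has the same origin as before: ``a/0`` has an infinite derivative, which the unselected branch then multiplies by zero. A short sketch of just that step (an illustration, not part of this commit):

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
(a / 0).backward()

# d(a/0)/da = 1/0 = inf. Inside torch.where(b, a/0, c) with b = False,
# the backward pass scales this inf by the branch weight 0, and
# 0 * inf = nan -- the (tensor(nan),) gradient shown above.
print(a.grad)  # tensor(inf)
```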
`Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__ -- :func:`torch.nansum` and :func:`torch.nanmean`
-------------------------------------------------------------------------------------------------------------------

Current result:

>>> a = torch.tensor([1., 2., float('nan')])
>>> b = torch.tensor(1.0, requires_grad=True)
>>> c = a * b
>>> c1 = torch.nansum(c)
>>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
>>> bgrad1
tensor(nan)

:class:`MaskedTensor` result:

>>> a = torch.tensor([1., 2., float('nan')])
>>> b = torch.tensor(1.0, requires_grad=True)
>>> mt = masked_tensor(a, ~torch.isnan(a))
>>> c = mt * b
>>> c1 = torch.sum(c)
>>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
>>> bgrad1
MaskedTensor( 3.0000, True)

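The issue covers :func:`torch.nanmean` as well; a quick plain-tensor check, analogous to the nansum snippet above (an illustration, not part of this commit), shows the same nan gradient:

```python
import torch

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c1 = torch.nanmean(a * b)
bgrad, = torch.autograd.grad(c1, b)

# nanmean skips the nan entry in the forward pass, but the chain rule
# still multiplies through a * b, and the nan entry of a poisons the sum.
print(bgrad)  # tensor(nan)
```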
`Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__ -- when using mask, x/0 yields NaN grad
-------------------------------------------------------------------------------------------------------

Current result:

>>> x = torch.tensor([1., 1.], requires_grad=True)
>>> div = torch.tensor([0., 1.])
>>> y = x/div # => y is [inf, 1]
>>> mask = (div != 0) # => mask is [0, 1]
>>> y[mask].backward()
>>> x.grad # grad is [nan, 1], but expected [0, 1]
tensor([nan, 1.])

:class:`MaskedTensor` result:

>>> x = torch.tensor([1., 1.], requires_grad=True)
>>> div = torch.tensor([0., 1.])
>>> y = x/div # => y is [inf, 1]
>>>
>>> mask = (div != 0) # => mask is [0, 1]
>>> loss = as_masked_tensor(y, mask)
>>> loss.sum().backward()
>>> x.grad
MaskedTensor(
  [ --, 1.0000]
)
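For comparison, a common plain-tensor workaround (not from this commit) is to index *before* dividing, so the zero divisor never enters the autograd graph at all:

```python
import torch

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0

# Selecting before the division keeps the element with the zero divisor
# out of the graph, so neither the forward nor the backward pass ever
# produces inf or nan.
y = x[mask] / div[mask]
y.sum().backward()
print(x.grad)  # tensor([0., 1.])
```

The MaskedTensor version above is more direct: it lets you compute ``x/div`` eagerly and mask afterwards, rather than restructuring the computation around the indexing.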

index.rst

Lines changed: 2 additions & 0 deletions
@@ -811,6 +811,8 @@ Additional Resources
    :caption: MaskedTensor

    beginner/maskedtensor_overview
+   beginner/maskedtensor_sparsity
+   beginner/maskedtensor_distinguish_gradient


 .. toctree::
