|
| 1 | +Distinguishing between 0 and NaN gradient |
| 2 | +----------------------------------------- |
| 3 | + |
| 4 | +One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are not |
| 5 | +defined (NaN) vs. gradients that are actually 0. By way of example, below are several different issues where |
| 6 | +:class:`MaskedTensor` can resolve and/or work around the NaN gradient problem. |
| 7 | + |
| 8 | +`Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__ -- torch.where |
| 9 | +-------------------------------------------------------------------------------- |
| 10 | + |
| 11 | +Current result: |
| 12 | + |
| 13 | + >>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float) |
| 14 | + >>> y = torch.where(x < 0, torch.exp(x), torch.ones_like(x)) |
| 15 | + >>> y.sum().backward() |
| 16 | + >>> x.grad |
| 17 | + tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, |
| 18 | + 0.0000e+00, 0.0000e+00, 0.0000e+00, nan, nan]) |
| 19 | + |
| 20 | +:class:`MaskedTensor` result: |
| 21 | + |
| 22 | + >>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100]) |
| 23 | + >>> mask = x < 0 |
| 24 | + >>> mx = masked_tensor(x, mask, requires_grad=True) |
| 25 | + >>> my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True) |
| 26 | + >>> y = torch.where(mask, torch.exp(mx), my) |
| 27 | + >>> y.sum().backward() |
| 28 | + >>> mx.grad |
| 29 | + MaskedTensor( |
| 30 | + [ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --] |
| 31 | + ) |
| 32 | + |
| 33 | +The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where` |
| 34 | +to mask out elements instead of setting them to zero. |
| 35 | + |
| 36 | +`Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ -- another torch.where |
| 37 | +---------------------------------------------------------------------------------------- |
| 38 | + |
| 39 | +Current result: |
| 40 | + |
| 41 | + >>> a = torch.randn((), requires_grad=True) |
| 42 | + >>> b = torch.tensor(False) |
| 43 | + >>> c = torch.ones(()) |
| 44 | + >>> torch.where(b, a/0, c) |
| 45 | + tensor(1., grad_fn=<WhereBackward0>) |
| 46 | + >>> torch.autograd.grad(torch.where(b, a/0, c), a) |
| 47 | + (tensor(nan),) |
| 48 | + |
| 49 | +:class:`MaskedTensor` result: |
| 50 | + |
| 51 | + >>> a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True) |
| 52 | + >>> b = torch.tensor(False) |
| 53 | + >>> c = torch.ones(()) |
| 54 | + >>> torch.where(b, a/0, c) |
| 55 | + MaskedTensor( 1.0000, True) |
| 56 | + >>> torch.autograd.grad(torch.where(b, a/0, c), a) |
| 57 | + (MaskedTensor(--, False),) |
| 58 | + |
| 59 | +`Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__ -- :func:`torch.nansum` and :func:`torch.nanmean` |
| 60 | +------------------------------------------------------------------------------------------------------------------- |
| 61 | + |
| 62 | +Current result: |
| 63 | + |
| 64 | + >>> a = torch.tensor([1., 2., float('nan')]) |
| 65 | + >>> b = torch.tensor(1.0, requires_grad=True) |
| 66 | + >>> c = a * b |
| 67 | + >>> c1 = torch.nansum(c) |
| 68 | + >>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True) |
| 69 | + >>> bgrad1 |
| 70 | + tensor(nan) |
| 71 | + |
| 72 | +:class:`MaskedTensor` result: |
| 73 | + |
| 74 | + >>> a = torch.tensor([1., 2., float('nan')]) |
| 75 | + >>> b = torch.tensor(1.0, requires_grad=True) |
| 76 | + >>> mt = masked_tensor(a, ~torch.isnan(a)) |
| 77 | + >>> c = mt * b |
| 78 | + >>> c1 = torch.sum(c) |
| 79 | + >>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True) |
| 80 | + >>> bgrad1 |
| 81 | + MaskedTensor( 3.0000, True) |
| 82 | + |
| 83 | +`Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__ -- when using mask, x/0 yields NaN grad |
| 84 | +------------------------------------------------------------------------------------------------------- |
| 85 | + |
| 86 | +Current result: |
| 87 | + |
| 88 | + >>> x = torch.tensor([1., 1.], requires_grad=True) |
| 89 | + >>> div = torch.tensor([0., 1.]) |
| 90 | + >>> y = x/div # => y is [inf, 1] |
| 91 | + >>> mask = (div != 0) # => mask is [0, 1] |
| 92 | + >>> y[mask].backward() |
| 93 | + >>> x.grad # grad is [nan, 1], but expected [0, 1] |
| 94 | + tensor([nan, 1.]) |
| 95 | + |
| 96 | +:class:`MaskedTensor` result: |
| 97 | + |
| 98 | + >>> x = torch.tensor([1., 1.], requires_grad=True) |
| 99 | + >>> div = torch.tensor([0., 1.]) |
| 100 | + >>> y = x/div # => y is [inf, 1] |
| 101 | + >>> |
| 102 | + >>> mask = (div != 0) # => mask is [0, 1] |
| 103 | + >>> loss = as_masked_tensor(y, mask) |
| 104 | + >>> loss.sum().backward() |
| 105 | + >>> x.grad |
| 106 | + MaskedTensor( |
| 107 | + [ --, 1.0000] |
| 108 | + ) |
0 commit comments