Skip to content

Commit 1a0de45

Browse files
committed
[maskedtensor] Add missing nan ops tutorial
ghstack-source-id: 69dd2dd Pull Request resolved: #2046
1 parent 33fc9d9 commit 1a0de45

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
Implementing missing torch.nan* operators
2+
-----------------------------------------
3+
4+
In the above issue, there is a request to add additional operators to cover the various `torch.nan*` applications,
5+
such as ``torch.nanmax``, ``torch.nanmin``, etc.
6+
7+
In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
8+
operators, we propose using MaskedTensors instead. Since
9+
`nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__, we can use it as a comparison point:
10+
11+
>>> x = torch.arange(16).float()
12+
>>> y = x * x.fmod(4)
13+
>>> y = y.masked_fill(y ==0, float('nan'))
14+
>>> y
15+
tensor([nan, 1., 4., 9., nan, 5., 12., 21., nan, 9., 20., 33., nan, 13.,
16+
28., 45.])
17+
>>> y.nanmean()
18+
tensor(16.6667)
19+
>>> torch.mean(masked_tensor(y, ~torch.isnan(y)))
20+
MaskedTensor( 16.6667, True)
21+
22+
:class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
23+
to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
24+
(an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
25+
26+
>>> x = torch.empty(16).fill_(float('nan'))
27+
>>> x
28+
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
29+
>>> torch.nanmean(x)
30+
tensor(nan)
31+
>>> torch.mean(masked_tensor(x, ~torch.isnan(x)))
32+
MaskedTensor(--, False)

0 commit comments

Comments
 (0)