Implementing missing torch.nan* operators
-----------------------------------------

In the above issue, there is a request to add additional operators to cover the various ``torch.nan*``
applications, such as ``torch.nanmax``, ``torch.nanmin``, etc.

In general, these problems lend themselves more naturally to masked semantics, so instead of introducing
additional operators, we propose using :class:`MaskedTensor`. Since
`nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__, we can use it as a
comparison point:

  >>> x = torch.arange(16).float()
  >>> y = x * x.fmod(4)
  >>> y = y.masked_fill(y == 0, float('nan'))
  >>> y
  tensor([nan,  1.,  4.,  9., nan,  5., 12., 21., nan,  9., 20., 33., nan, 13.,
          28., 45.])
  >>> y.nanmean()
  tensor(16.6667)
  >>> torch.mean(masked_tensor(y, ~torch.isnan(y)))
  MaskedTensor( 16.6667, True)

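For the operators that do not yet exist, the same masked idea can be applied by hand. As a minimal sketch (the helper name ``nanmax`` is hypothetical and not part of torch), the requested ``torch.nanmax`` can be emulated on a plain tensor by replacing ``nan`` entries with ``-inf`` before reducing, so they can never win the max:

```python
import torch

# Hypothetical helper: emulate the requested torch.nanmax by filling
# nan entries with -inf before reducing, so they never affect the max.
def nanmax(t: torch.Tensor) -> torch.Tensor:
    return t.masked_fill(t.isnan(), float('-inf')).max()

y = torch.tensor([float('nan'), 1., 4., 9., float('nan'), 5.])
print(nanmax(y))  # tensor(9.)
```

This works for a one-off, but every ``nan*`` variant would need its own sentinel value (``-inf`` for max, ``inf`` for min, ``0`` for sum), which is exactly the per-operator duplication that masked semantics avoid.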
:class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
to the case above where the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
(an ambiguous return value), while MaskedTensor would more accurately indicate a masked-out result.

  >>> x = torch.empty(16).fill_(float('nan'))
  >>> x
  tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
  >>> torch.nanmean(x)
  tensor(nan)
  >>> torch.mean(masked_tensor(x, ~torch.isnan(x)))
  MaskedTensor(--, False)
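
The fully-masked case is also where sentinel-based workarounds break down. Continuing the sketch above (the sentinel trick, not an official API): filling ``nan`` with ``-inf`` before a max reduces an all-``nan`` input to ``-inf``, which is just as ambiguous as ``nanmean`` returning ``nan``, whereas the masked reduction above explicitly reports that every element was masked out:

```python
import torch

# Sentinel-based emulation of a nan-aware max: on a fully-nan input,
# every element becomes -inf, so the "result" -inf is indistinguishable
# from a genuine -inf in the data -- an ambiguity MaskedTensor avoids.
x = torch.full((4,), float('nan'))
result = x.masked_fill(x.isnan(), float('-inf')).max()
print(result)  # tensor(-inf)
```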