"""
(Prototype) MaskedTensor Overview
*********************************
"""

######################################################################
# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)
# * differentiation between 0 and NaN gradients
# * various sparse applications (see tutorial below)
#
# For a more detailed introduction on what MaskedTensors are, please refer to the
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
#
# Using MaskedTensor
# ==================
#
# In this section we discuss how to use MaskedTensor, including how to construct and access
# the data and mask, as well as indexing and slicing.
#
# Preparation
# -----------
#
# We'll begin by doing the necessary setup for the tutorial:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

######################################################################
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class (a brief sketch follows below)
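#
# As a quick sketch (the values below are made up for illustration), pairing a
# data tensor with a boolean mask of the same shape yields a MaskedTensor; here
# we use the ``masked_tensor`` factory imported in the setup:

example_data = torch.tensor([0., 1., 2.])
example_mask = torch.tensor([True, False, True])
print(masked_tensor(example_data, example_mask))

######################################################################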
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)

######################################################################
#

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# ===========================
#
# Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, nans, etc.), it is able to address several of the shortcomings
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# torch.where
# ^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where` because we have trouble differentiating whether the 0 is a real 0
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
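# As a minimal sketch of this behavior (tensor values made up for illustration,
# following the masked pattern discussed above), masking both branches of
# :func:`torch.where` yields gradients that are masked out rather than zeroed:

t = torch.tensor([-10., -5., 0., 5., 10.])
sel = t < 0
m_if = masked_tensor(t, sel, requires_grad=True)
m_else = masked_tensor(torch.ones_like(t), ~sel, requires_grad=True)
out = torch.where(sel, torch.exp(m_if), m_else)
out.sum().backward()
# the gradient is only defined on the selected subset; the rest is masked out
print(m_if.grad)

######################################################################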
# Another torch.where
# ^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div  # => y is [inf, 1]
mask = (div != 0)  # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ----------------------------------------------
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
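#
# As a rough sketch of the contrast (made-up values; the exact printed results
# may vary by PyTorch version), a NaN entry still poisons the gradient of
# :func:`torch.nansum`, while masking it out avoids the problem entirely:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
# regular tensor: the forward pass skips the nan, but the backward pass does not
bgrad, = torch.autograd.grad((a * b).nansum(), b)
print(bgrad)

# masked tensor: the nan entry is simply unspecified, so it never enters autograd
mt = masked_tensor(a, ~torch.isnan(a))
mgrad, = torch.autograd.grad((mt * b).sum(), b)
print(mgrad)

######################################################################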
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
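#
# As a brief sketch (random data and a made-up mask, in the spirit of the
# scenario above), a fully masked-out column becomes NaN under a regular
# softmax, while the masked softmax simply leaves it masked out:

sm_data = torch.randn(3, 3)
sm_mask = torch.tensor([[True, False, False],
                        [True, False, True],
                        [False, False, False]])  # middle column: all padding
print(sm_data.masked_fill(~sm_mask, float('-inf')).softmax(0))  # nan column
print(masked_tensor(sm_data, sm_mask).softmax(0))  # column stays masked out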

######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing
# additional operators, we propose using :class:`MaskedTensor`.
# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,
# we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y when ignoring the zeros

######################################################################
#
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)

######################################################################
#

print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
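
# A sketch of the masked counterpart (mt_y is our hypothetical name, built with
# the masked_tensor factory from the setup): the mean simply skips masked slots
mt_y = masked_tensor(y, y != 0)
print("mt_y.mean():\n", mt_y.mean())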

######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` but what we really want is an undefined value.
#
# Conclusion
# ==========
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# ===============
#
# To continue learning more, you can find our
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__