Commit c264833

[maskedtensor] Overview tutorial [1/4]
1 parent 04e1ba9 commit c264833

File tree: 2 files changed, +320 -0 lines changed

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
=================================
"""

######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#
# MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to:
#
# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)
# * differentiate between 0 and NaN gradients
# * support various sparse applications (see the tutorial below)
#
# For a more detailed introduction on what MaskedTensors are, please refer to the
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
#
# Using MaskedTensor
# ++++++++++++++++++
#
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended way) is to use the :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
#
# Throughout this tutorial, we will assume the import line: `from torch.masked import masked_tensor, as_masked_tensor`.
#
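# As a minimal, illustrative sketch of the recommended factory-function construction
# (the data and mask values below are our own and purely illustrative):

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, False, True])

# masked_tensor is analogous to torch.tensor
mt = masked_tensor(data, mask)
# as_masked_tensor is analogous to torch.as_tensor
mt2 = as_masked_tensor(data, mask)
print("mt:\n", mt)
print("mt2:\n", mt2)

######################################################################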
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that
# when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor` (as sketched below) to
# return a Tensor with filled values.
#
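# A minimal sketch of these accessors, reusing the illustrative ``mt`` constructed above;
# we assume here that ``to_tensor`` takes the fill value for the unspecified entries:

print("get_data:\n", mt.get_data())
print("get_mask:\n", mt.get_mask())
# fill the unspecified entries with 0. to obtain a regular Tensor
print("to_tensor(0.):\n", mt.to_tensor(0.))

######################################################################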
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)

# float is used for cleaner visualization when printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# +++++++++++++++++++++++++++
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, NaNs, and so on), it is able to address several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics, since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that the gradient can vary depending on the order of operations (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where`, because we have trouble differentiating whether the 0 is a real 0
# or one that comes from an undefined gradient. Therefore, we remain consistent and mask out the results:
#
# Current result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
print("x.grad:\n", x.grad)

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
print("mx.grad:\n", mx.grad)

######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
# Another :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with
# the unexpected behavior caused by the inability to differentiate "no gradient" from "zero gradient",
# which in turn makes working with other ops difficult to reason about.
#
# When using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that
# `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
print("x.grad:\n", x.grad)

######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
print("x.grad:\n", x.grad)

######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
print("bgrad1:\n", bgrad1)

######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
print("bgrad1:\n", bgrad1)

######################################################################
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

######################################################################
# Suppose, for example, that we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (that is, entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

print("x.softmax(0):\n", x.softmax(0))

######################################################################
# :class:`MaskedTensor` result:
#

print("mt.softmax(0):\n", mt.softmax(0))

######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor` instead. Since
# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y while ignoring the zeros

print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

######################################################################
# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

print(torch.argmin(masked_tensor(y, y != 0)))

######################################################################
# Indeed, when the 0's are ignored, the index of the minimum value is index 1, whose value is 1.
#
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
# to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

######################################################################
# This is a similar problem to safe softmax where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ++++++++++
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# +++++++++++++++
#
# To continue learning more, you can find our
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__
# to see how MaskedTensor enables sparsity and the different storage formats we currently support.
#

prototype_source/prototype_index.rst

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,15 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/nestedtensor.html
   :tags: NestedTensor

.. MaskedTensor

.. customcarditem::
   :header: MaskedTensor Overview
   :card_description: Learn about masked tensors, the source of truth for specified and unspecified values
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/maskedtensor_overview.html
   :tags: MaskedTensor

.. End of tutorial card section

.. raw:: html

@@ -172,3 +181,4 @@ Prototype features are not available as part of binary distributions like PyPI o
   prototype/vmap_recipe.html
   prototype/vulkan_workflow.html
   prototype/nestedtensor.html
   prototype/maskedtensor_overview.html
