From 4eaaf669bc46cadbde4e1e4a2306e95e3d9fe0a5 Mon Sep 17 00:00:00 2001 From: George Qi Date: Tue, 20 Sep 2022 00:41:02 +0000 Subject: [PATCH 1/6] [maskedtensor] Add overview tutorial [ghstack-poisoned] --- beginner_source/maskedtensor_overview.rst | 244 ++++++++++++++++++++++ index.rst | 9 + 2 files changed, 253 insertions(+) create mode 100644 beginner_source/maskedtensor_overview.rst diff --git a/beginner_source/maskedtensor_overview.rst b/beginner_source/maskedtensor_overview.rst new file mode 100644 index 00000000000..068cd1668cb --- /dev/null +++ b/beginner_source/maskedtensor_overview.rst @@ -0,0 +1,244 @@ +MaskedTensor Overview +===================== + +This tutorial is designed to serve as a starting point for using MaskedTensors +and discuss its masking semantics. + +Using MaskedTensor +++++++++++++++++++ + +Construction +------------ + +There are a few different ways to construct a MaskedTensor: + +* The first way is to directly invoke the MaskedTensor class +* The second (and our recommended way) is to use :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor` factory functions, + which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor` + + .. autosummary:: + :toctree: generated + :nosignatures: + + masked.masked_tensor + masked.as_masked_tensor + +Accessing the data and mask +--------------------------- + +The underlying fields in a MaskedTensor can be accessed through: + +* the :meth:`MaskedTensor.get_data` function +* the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid" while ``False`` indicates + "unspecified" or "invalid". + +In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that +when users require a Tensor without any masked entries, that they use :meth:`MaskedTensor.to_tensor` (as shown above) to +return a Tensor with filled values. + +Indexing and slicing +-------------------- + +:class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing +as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns: + + >>> data = torch.arange(60).reshape(3, 4, 5) + >>> mask = data % 2 == 0 + >>> mt = masked_tensor(data.float(), mask) + >>> mt[0] + MaskedTensor( + [ + [ 0.0000, --, 2.0000, --, 4.0000], + [ --, 6.0000, --, 8.0000, --], + [ 10.0000, --, 12.0000, --, 14.0000], + [ --, 16.0000, --, 18.0000, --] + ] + ) + >>> mt[[0,2]] + MaskedTensor( + [ + [ + [ 0.0000, --, 2.0000, --, 4.0000], + [ --, 6.0000, --, 8.0000, --], + [ 10.0000, --, 12.0000, --, 14.0000], + [ --, 16.0000, --, 18.0000, --] + ], + [ + [ 40.0000, --, 42.0000, --, 44.0000], + [ --, 46.0000, --, 48.0000, --], + [ 50.0000, --, 52.0000, --, 54.0000], + [ --, 56.0000, --, 58.0000, --] + ] + ] + ) + >>> mt[:, :2] + MaskedTensor( + [ + [ + [ 0.0000, --, 2.0000, --, 4.0000], + [ --, 6.0000, --, 8.0000, --] + ], + [ + [ 20.0000, --, 22.0000, --, 24.0000], + [ --, 26.0000, --, 28.0000, --] + ], + [ + [ 40.0000, --, 42.0000, --, 44.0000], + [ --, 46.0000, --, 48.0000, --] + ] + ] + ) + +Semantics ++++++++++ + +MaskedTensor vs NumPy's MaskedArray +----------------------------------- + +NumPy's ``MaskedArray`` has a few fundamental semantics differences from MaskedTensor. + +1. 
Their factory function and basic definition invert the mask (similar to ``torch.nn.MultiheadAttention``); that is, MaskedTensor
uses ``True`` to denote "specified" and ``False`` to denote "unspecified", or "valid"/"invalid", whereas NumPy does the
opposite.
2. Intersection semantics. In NumPy, if either of two elements is masked out, the resulting element will be
masked out as well -- in practice, they
`apply the logical_or operator `__.

    >>> data = torch.arange(5.)
    >>> mask = torch.tensor([True, True, False, True, False])
    >>> npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
    >>> npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())
    >>> npm0
    masked_array(data=[0.0, 1.0, --, 3.0, --],
                 mask=[False, False, True, False, True],
           fill_value=1e+20,
                dtype=float32)
    >>> npm1
    masked_array(data=[--, --, 2.0, --, 4.0],
                 mask=[ True, True, False, True, False],
           fill_value=1e+20,
                dtype=float32)
    >>> npm0 + npm1
    masked_array(data=[--, --, --, --, --],
                 mask=[ True, True, True, True, True],
           fill_value=1e+20,
                dtype=float32)

Meanwhile, MaskedTensor does not support addition or binary operators with masks that don't match -- to understand why,
please see the :ref:`reduction semantics <reduction-semantics>` section below.

    >>> mt0 = masked_tensor(data, mask)
    >>> mt1 = masked_tensor(data, ~mask)
    >>> mt0
    MaskedTensor(
      [  0.0000,   1.0000,       --,   3.0000,       --]
    )
    >>> mt1
    MaskedTensor(
      [      --,       --,   2.0000,       --,   4.0000]
    )
    >>> mt0 + mt1
    ValueError: Input masks must match. If you need support for this, please open an issue on Github.

However, if this behavior is desired, MaskedTensor does support these semantics by giving access to the data and masks
and conveniently converting a MaskedTensor to a Tensor with masked values filled in using :func:`to_tensor`.

    >>> t0 = mt0.to_tensor(0)
    >>> t1 = mt1.to_tensor(0)
    >>> mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())
    >>> t0
    tensor([0., 1., 0., 3., 0.])
    >>> t1
    tensor([0., 0., 2., 0., 4.])
    >>> mt2
    MaskedTensor(
      [      --,       --,       --,       --,       --]
    )

.. _reduction-semantics:

Reduction semantics
-------------------

The basis for reduction semantics `has been documented and discussed at length `__,
but again, by way of example:

    >>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
    >>> mask = torch.randint(2, (3, 4), dtype=torch.bool)
    >>> mt = masked_tensor(data, mask)
    >>> mt
    MaskedTensor(
      [
        [      --,   1.0000,       --,       --],
        [      --,   5.0000,   6.0000,   7.0000],
        [  8.0000,   9.0000,       --,  11.0000]
      ]
    )

    >>> torch.sum(mt, 1)
    MaskedTensor(
      [  1.0000,  18.0000,  28.0000]
    )
    >>> torch.mean(mt, 1)
    MaskedTensor(
      [  1.0000,   6.0000,   9.3333]
    )
    >>> torch.prod(mt, 1)
    MaskedTensor(
      [  1.0000, 210.0000, 792.0000]
    )
    >>> torch.amin(mt, 1)
    MaskedTensor(
      [  1.0000,   5.0000,   8.0000]
    )
    >>> torch.amax(mt, 1)
    MaskedTensor(
      [  1.0000,   7.0000,  11.0000]
    )

Now we can revisit the question: why do we enforce the invariant that masks must match for binary operators?
In other words, why don't we use the same semantics as ``np.ma.masked_array``?
Consider the following example: + + >>> data0 = torch.arange(10.).reshape(2, 5) + >>> data1 = torch.arange(10.).reshape(2, 5) + 10 + >>> mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]]) + >>> mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]]) + + >>> npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy()) + >>> npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy()) + >>> npm0 + masked_array( + data=[[--, --, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, --, --]], + mask=[[ True, True, False, False, False], + [False, False, False, True, True]], + fill_value=1e+20, + dtype=float32) + >>> npm1 + masked_array( + data=[[10.0, 11.0, 12.0, --, --], + [--, --, 17.0, 18.0, 19.0]], + mask=[[False, False, False, True, True], + [ True, True, False, False, False]], + fill_value=1e+20, + dtype=float32) + >>> (npm0 + npm1).sum(0) + masked_array(data=[--, --, 38.0, --, --], + mask=[ True, True, False, True, True], + fill_value=1e+20, + dtype=float32) + >>> npm0.sum(0) + npm1.sum(0) + masked_array(data=[15.0, 17.0, 38.0, 21.0, 23.0], + mask=[False, False, False, False, False], + fill_value=1e+20, + dtype=float32) + +Sum and addition should clearly be associative, but with NumPy's semantics, they are allowed to not be, +which can certainly be confusing for the user. That being said, if the user wishes, there are ways around this +(e.g. filling in the MaskedTensor's undefined elements with 0 values using :func:`to_tensor` as shown in a previous +example), but the user must now be more explicit with their intentions. diff --git a/index.rst b/index.rst index 89f04219d87..29e4d62fe04 100644 --- a/index.rst +++ b/index.rst @@ -804,6 +804,15 @@ Additional Resources beginner/translation_transformer +.. toctree:: + :maxdepth: 2 + :includehidden: + :hidden: + :caption: MaskedTensor + + beginner/maskedtensor_overview + + .. toctree:: :maxdepth: 2 :includehidden: From f523304162921751b325f7c101b81e2dd213c258 Mon Sep 17 00:00:00 2001 From: George Qi Date: Tue, 20 Sep 2022 00:41:19 +0000 Subject: [PATCH 2/6] [maskedtensor] Add sparsity tutorial [ghstack-poisoned] --- beginner_source/maskedtensor_sparsity.rst | 218 ++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 beginner_source/maskedtensor_sparsity.rst diff --git a/beginner_source/maskedtensor_sparsity.rst b/beginner_source/maskedtensor_sparsity.rst new file mode 100644 index 00000000000..f374e8edca7 --- /dev/null +++ b/beginner_source/maskedtensor_sparsity.rst @@ -0,0 +1,218 @@ +Sparsity +++++++++ + +Sparsity has been an area of rapid growth and importance within PyTorch; if any sparsity terms are confusing below, +please refer to the `sparsity tutorial `__ for additional details. + +Sparse storage formats have been proven to be powerful in a variety of ways. As a primer, the first use case +most practitioners think about is when the majority of elements are equal to zero (a high degree of sparsity), +but even in cases of lower sparsity, certain formats (e.g. BSR) can take advantage of substructures within a matrix. + +.. note:: + + At the moment, MaskedTensor supports COO and CSR tensors with plans to support additional formats + (e.g. BSR and CSC) in the future. If you have any requests for additional formats, please file a feature request! + +Principles +---------- + +When creating a :class:`MaskedTensor` with sparse tensors, there are a few principles that must be observed: + +1. 
``data`` and ``mask`` must have the same storage format, whether that's :attr:`torch.strided`, :attr:`torch.sparse_coo`, or :attr:`torch.sparse_csr` +2. ``data`` and ``mask`` must have the same size, indicated by :func:`size()` + +Sparse COO tensors +------------------ + +In accordance with Principle #1, a sparse COO MaskedTensor is created by passing in two sparse COO tensors, +which can be initialized by any of its constructors, e.g. :func:`torch.sparse_coo_tensor`. + +As a recap of `sparse COO tensors `__, the COO format +stands for "coordinate format", where the specified elements are stored as tuples of their indices and the +corresponding values. That is, the following are provided: + +* ``indices``: array of size ``(ndim, nse)`` and dtype ``torch.int64`` +* ``values``: array of size `(nse,)` with any integer or floating point dtype + +where ``ndim`` is the dimensionality of the tensor and ``nse`` is the number of specified elements + +For both sparse COO and CSR tensors, you can construct a :class:`MaskedTensor` by doing either: + +1. ``masked_tensor(sparse_tensor_data, sparse_tensor_mask)`` +2. ``dense_masked_tensor.to_sparse_coo()`` or ``dense_masked_tensor.to_sparse_csr()`` + +The second method is easier to illustrate so we've shown that below, but for more on the first and the nuances behind +the approach, please read the :ref:`sparse-coo-appendix`. + + >>> values = torch.tensor([[0, 0, 3], [4, 0, 5]]) + >>> mask = torch.tensor([[False, False, True], [False, False, True]]) + >>> mt = masked_tensor(values, mask) + >>> sparse_coo_mt = mt.to_sparse_coo() + >>> mt + MaskedTensor( + [ + [ --, --, 3], + [ --, --, 5] + ] + ) + >>> sparse_coo_mt + MaskedTensor( + [ + [ --, --, 3], + [ --, --, 5] + ] + ) + >>> sparse_coo_mt.get_data() + tensor(indices=tensor([[0, 1], + [2, 2]]), + values=tensor([3, 5]), + size=(2, 3), nnz=2, layout=torch.sparse_coo) + +Sparse CSR tensors +------------------ + +Similarly, :class:`MaskedTensor` also supports the +`CSR (Compressed Sparse Row) `__ +sparse tensor format. Instead of storing the tuples of the indices like sparse COO tensors, sparse CSR tensors +aim to decrease the memory requirements by storing compressed row indices. +In particular, a CSR sparse tensor consists of three 1-D tensors: + +* ``crow_indices``: array of compressed row indices with size ``(size[0] + 1,)``. This array indicates which row + a given entry in values lives in. The last element is the number of specified elements, + while crow_indices[i+1] - crow_indices[i] indicates the number of specified elements in row i. +* ``col_indices``: array of size ``(nnz,)``. Indicates the column indices for each value. +* ``values``: array of size ``(nnz,)``. Contains the values of the CSR tensor. + +Of note, both sparse COO and CSR tensors are in a `beta `__ state. + +By way of example: + + >>> mt_sparse_csr = mt.to_sparse_csr() + >>> mt_sparse_csr + MaskedTensor( + [ + [ --, --, 3], + [ --, --, 5] + ] + ) + >>> mt_sparse_csr.get_data() + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([2, 2]), + values=tensor([3, 5]), size=(2, 3), nnz=2, layout=torch.sparse_csr) + +Appendix +++++++++ + +.. _sparse-coo-appendix: + +Sparse COO construction +----------------------- + +Recall in our original example, we created a :class:`MaskedTensor` and then converted it to a sparse COO MaskedTensor +with :meth:`MaskedTensor.to_sparse_coo`. 
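For convenience, here is a condensed sketch of that earlier conversion (re-using the same ``values`` and ``mask``
from the example above; the trailing ``.layout`` check is only an illustrative addition, not part of the original example):

    >>> values = torch.tensor([[0, 0, 3], [4, 0, 5]])
    >>> mask = torch.tensor([[False, False, True], [False, False, True]])
    >>> sparse_coo_mt = masked_tensor(values, mask).to_sparse_coo()
    >>> sparse_coo_mt.get_data().layout
    torch.sparse_coo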
+ +Alternatively, we can also construct a sparse COO MaskedTensor directly by passing in two sparse COO tensors: + + >>> values = torch.tensor([[0, 0, 3], [4, 0, 5]]).to_sparse() + >>> mask = torch.tensor([[False, False, True], [False, False, True]]).to_sparse() + >>> mt = masked_tensor(values, mask) + >>> values + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3, 4, 5]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> mask + tensor(indices=tensor([[0, 1], + [2, 2]]), + values=tensor([True, True]), + size=(2, 3), nnz=2, layout=torch.sparse_coo) + >>> mt + MaskedTensor( + [ + [ --, --, 3], + [ --, --, 5] + ] + ) + +Instead of using :meth:`torch.Tensor.to_sparse`, we can also create the sparse COO tensors directly, which brings us to a warning: + +.. warning:: + + When using a function like :meth:`MaskedTensor.to_sparse_coo`, if the user does not specify the indices like in the above + example, then the 0 values will be "unspecified" by default. + +Below, we explicitly specify the 0's: + + >>> values = torch.sparse_coo_tensor(i, v, (2, 3)) + >>> mask = torch.sparse_coo_tensor(i, m, (2, 3)) + >>> mt2 = masked_tensor(values, mask) + >>> values + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3, 4, 5]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> mask + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([ True, False, True]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> mt2 + MaskedTensor( + [ + [ --, --, 3], + [ --, --, 5] + ] + ) + +Note that ``mt`` and ``mt2`` look identical on the surface, and in the vast majority of operations, will yield the same +result. But this brings us to a detail on the implementation: + +``data`` and ``mask`` -- only for sparse MaskedTensors -- can have a different number of elements (:func:`nnz`) +**at creation**, but the indices of ``mask`` must then be a subset of the indices of ``data``. In this case, +``data`` will assume the shape of ``mask`` by ``data = data.sparse_mask(mask)``; in other words, any of the elements +in ``data`` that are not ``True`` in ``mask`` (i.e. not specified) will be thrown away. + +Therefore, under the hood, the data looks slightly different; ``mt2`` has the "4" value masked out and ``mt`` is completely +without it. Their underlying data has different shapes, which would make operations like ``mt + mt2`` invalid. + + >>> mt.get_data() + tensor(indices=tensor([[0, 1], + [2, 2]]), + values=tensor([3, 5]), + size=(2, 3), nnz=2, layout=torch.sparse_coo) + >>> mt2.get_data() + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3, 4, 5]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + +.. _sparse-csr-appendix: + +Sparse CSR construction +----------------------- + +We can also construct a sparse CSR MaskedTensor using sparse CSR tensors, +and like the example above, this results in a similar treatment under the hood. 
+ + >>> crow_indices = torch.tensor([0, 2, 4]) + >>> col_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([1, 2, 3, 4]) + >>> mask_values = torch.tensor([True, False, False, True]) + >>> + >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double) + >>> mask = torch.sparse_csr_tensor(crow_indices, col_indices, mask_values, dtype=torch.bool) + >>> + >>> mt = masked_tensor(csr, mask) + >>> mt + MaskedTensor( + [ + [ 1.0000, --], + [ --, 4.0000] + ] + ) + >>> mt.get_data() + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + From 3a1754c11e6cc1d399a93a9f2e34b700d95d1d88 Mon Sep 17 00:00:00 2001 From: George Qi Date: Tue, 20 Sep 2022 00:41:20 +0000 Subject: [PATCH 3/6] [maskedtensor] Distinguish between 0 and NaN gradient [ghstack-poisoned] --- .../maskedtensor_distinguish_gradient.rst | 108 ++++++++++++++++++ index.rst | 2 + 2 files changed, 110 insertions(+) create mode 100644 beginner_source/maskedtensor_distinguish_gradient.rst diff --git a/beginner_source/maskedtensor_distinguish_gradient.rst b/beginner_source/maskedtensor_distinguish_gradient.rst new file mode 100644 index 00000000000..6118f572548 --- /dev/null +++ b/beginner_source/maskedtensor_distinguish_gradient.rst @@ -0,0 +1,108 @@ +Distinguishing between 0 and NaN gradient +----------------------------------------- + +One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are not +defined (NaN) vs. gradients that are actually 0. By way of example, below are several different issues where +:class:`MaskedTensor` can resolve and/or work around the NaN gradient problem. + +`Issue 10729 `__ -- torch.where +-------------------------------------------------------------------------------- + +Current result: + + >>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float) + >>> y = torch.where(x < 0, torch.exp(x), torch.ones_like(x)) + >>> y.sum().backward() + >>> x.grad + tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, + 0.0000e+00, 0.0000e+00, 0.0000e+00, nan, nan]) + +:class:`MaskedTensor` result: + + >>> x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100]) + >>> mask = x < 0 + >>> mx = masked_tensor(x, mask, requires_grad=True) + >>> my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True) + >>> y = torch.where(mask, torch.exp(mx), my) + >>> y.sum().backward() + >>> mx.grad + MaskedTensor( + [ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --] + ) + +The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where` +to mask out elements instead of setting them to zero. 
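If downstream code still expects a plain :class:`torch.Tensor` gradient, one possible follow-up (a small sketch that
builds on the variables in the example above, with ``0`` chosen arbitrarily as the fill value) is to materialize the
masked gradient explicitly via :meth:`MaskedTensor.to_tensor`:

    >>> dense_grad = mx.grad.to_tensor(0.)   # unspecified entries are filled with 0
    >>> dense_grad.shape
    torch.Size([11])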
+ +`Issue 52248 `__ -- another torch.where +---------------------------------------------------------------------------------------- + +Current result: + + >>> a = torch.randn((), requires_grad=True) + >>> b = torch.tensor(False) + >>> c = torch.ones(()) + >>> torch.where(b, a/0, c) + tensor(1., grad_fn=) + >>> torch.autograd.grad(torch.where(b, a/0, c), a) + (tensor(nan),) + +:class:`MaskedTensor` result: + + >>> a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True) + >>> b = torch.tensor(False) + >>> c = torch.ones(()) + >>> torch.where(b, a/0, c) + MaskedTensor( 1.0000, True) + >>> torch.autograd.grad(torch.where(b, a/0, c), a) + (MaskedTensor(--, False),) + +`Issue 67180 `__ -- :func:`torch.nansum` and :func:`torch.nanmean` +------------------------------------------------------------------------------------------------------------------- + +Current result: + + >>> a = torch.tensor([1., 2., float('nan')]) + >>> b = torch.tensor(1.0, requires_grad=True) + >>> c = a * b + >>> c1 = torch.nansum(c) + >>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True) + >>> bgrad1 + tensor(nan) + +:class:`MaskedTensor` result: + + >>> a = torch.tensor([1., 2., float('nan')]) + >>> b = torch.tensor(1.0, requires_grad=True) + >>> mt = masked_tensor(a, ~torch.isnan(a)) + >>> c = mt * b + >>> c1 = torch.sum(c) + >>> bgrad1, = torch.autograd.grad(c1, b, retain_graph=True) + >>> bgrad1 + MaskedTensor( 3.0000, True) + +`Issue 4132 `__ -- when using mask, x/0 yields NaN grad +------------------------------------------------------------------------------------------------------- + +Current result: + + >>> x = torch.tensor([1., 1.], requires_grad=True) + >>> div = torch.tensor([0., 1.]) + >>> y = x/div # => y is [inf, 1] + >>> mask = (div != 0) # => mask is [0, 1] + >>> y[mask].backward() + >>> x.grad # grad is [nan, 1], but expected [0, 1] + tensor([nan, 1.]) + +:class:`MaskedTensor` result: + + >>> x = torch.tensor([1., 1.], requires_grad=True) + >>> div = torch.tensor([0., 1.]) + >>> y = x/div # => y is [inf, 1] + >>> + >>> mask = (div != 0) # => mask is [0, 1] + >>> loss = as_masked_tensor(y, mask) + >>> loss.sum().backward() + >>> x.grad + MaskedTensor( + [ --, 1.0000] + ) \ No newline at end of file diff --git a/index.rst b/index.rst index 29e4d62fe04..df01394eded 100644 --- a/index.rst +++ b/index.rst @@ -811,6 +811,8 @@ Additional Resources :caption: MaskedTensor beginner/maskedtensor_overview + beginner/maskedtensor_sparsity + beginner/maskedtensor_distinguish_gradient .. toctree:: From 8c7372556b82c2bfeab7f443575eac0c94704869 Mon Sep 17 00:00:00 2001 From: George Qi Date: Tue, 20 Sep 2022 00:41:22 +0000 Subject: [PATCH 4/6] [maskedtensor] Add safe softmax tutorial [ghstack-poisoned] --- beginner_source/maskedtensor_safe_softmax.rst | 33 +++++++++++++++++++ index.rst | 1 + 2 files changed, 34 insertions(+) create mode 100644 beginner_source/maskedtensor_safe_softmax.rst diff --git a/beginner_source/maskedtensor_safe_softmax.rst b/beginner_source/maskedtensor_safe_softmax.rst new file mode 100644 index 00000000000..0f34b479caf --- /dev/null +++ b/beginner_source/maskedtensor_safe_softmax.rst @@ -0,0 +1,33 @@ +Safe Softmax +------------ + +One of the issues that frequently comes up is the necessity for a safe softmax -- that is, if there is an entire +batch that is "masked out" or consists entirely of padding (which, in the softmax case, translates to being set `-inf`), +then this will result in NaNs, which can leading to training divergence. 
For more detail on why this functionality is necessary, please refer to
`Issue 55056 - Feature Request for Safe Softmax `__.

Luckily, :class:`MaskedTensor` has solved this issue:

    >>> data = torch.randn(3, 3)
    >>> mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
    >>> x = data.masked_fill(~mask, float('-inf'))
    >>> mt = masked_tensor(data, mask)

PyTorch result:

    >>> x.softmax(0)
    tensor([[0.3548,    nan, 0.0000],
            [0.6452,    nan, 1.0000],
            [0.0000,    nan, 0.0000]])

:class:`MaskedTensor` result:

    >>> mt.softmax(0)
    MaskedTensor(
      [
        [  0.3548,       --,       --],
        [  0.6452,       --,   1.0000],
        [      --,       --,       --]
      ]
    )

diff --git a/index.rst b/index.rst
index df01394eded..bb7a019c83b 100644
--- a/index.rst
+++ b/index.rst
@@ -813,6 +813,7 @@ Additional Resources
    beginner/maskedtensor_overview
    beginner/maskedtensor_sparsity
    beginner/maskedtensor_distinguish_gradient
+   beginner/maskedtensor_safe_softmax
 
 
 .. toctree::

From 46f2fbe45efd7fcd7ae84fdd076db97643df1d66 Mon Sep 17 00:00:00 2001
From: George Qi
Date: Tue, 20 Sep 2022 00:41:24 +0000
Subject: [PATCH 5/6] [maskedtensor] Add missing nan ops tutorial

[ghstack-poisoned]
---
 .../maskedtensor_missing_nan_ops.rst | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 beginner_source/maskedtensor_missing_nan_ops.rst

diff --git a/beginner_source/maskedtensor_missing_nan_ops.rst b/beginner_source/maskedtensor_missing_nan_ops.rst
new file mode 100644
index 00000000000..5fbcbabbc10
--- /dev/null
+++ b/beginner_source/maskedtensor_missing_nan_ops.rst
@@ -0,0 +1,32 @@
Implementing missing torch.nan* operators
-----------------------------------------

In the above issue, there is a request to add additional operators to cover the various ``torch.nan*`` applications,
such as ``torch.nanmax``, ``torch.nanmin``, etc.

In general, these problems lend themselves more naturally to masked semantics, so rather than introducing additional
operators, we propose using MaskedTensors instead. Since
`nanmean has already landed `__, we can use it as a comparison point:

    >>> x = torch.arange(16).float()
    >>> y = x * x.fmod(4)
    >>> y = y.masked_fill(y == 0, float('nan'))
    >>> y
    tensor([nan,  1.,  4.,  9., nan,  5., 12., 21., nan,  9., 20., 33., nan, 13.,
            28., 45.])
    >>> y.nanmean()
    tensor(16.6667)
    >>> torch.mean(masked_tensor(y, ~torch.isnan(y)))
    MaskedTensor( 16.6667, True)

:class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
(an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
    >>> x = torch.empty(16).fill_(float('nan'))
    >>> x
    tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
    >>> torch.nanmean(x)
    tensor(nan)
    >>> torch.mean(masked_tensor(x, ~torch.isnan(x)))
    MaskedTensor(--, False)

From 6f2478f4feb8193608daf09ba69da71e514cfc86 Mon Sep 17 00:00:00 2001
From: George Qi
Date: Tue, 20 Sep 2022 00:41:25 +0000
Subject: [PATCH 6/6] [maskedtensor] Add adagrad sparse semantics tutorial

[ghstack-poisoned]
---
 .../maskedtensor_adagrad_sparse_semantics.py | 180 ++++++++++++++++++
 index.rst | 2 +
 2 files changed, 182 insertions(+)
 create mode 100644 beginner_source/maskedtensor_adagrad_sparse_semantics.py

diff --git a/beginner_source/maskedtensor_adagrad_sparse_semantics.py b/beginner_source/maskedtensor_adagrad_sparse_semantics.py
new file mode 100644
index 00000000000..d6fae685306
--- /dev/null
+++ b/beginner_source/maskedtensor_adagrad_sparse_semantics.py
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-

"""
Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 `__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
But the code doesn't really use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension (i.e. unspecified values are not zero)
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
In the end, the code snippets are repeated without additional comments to show the difference in brevity.

"""

import torch
from torch.masked.maskedtensor import masked_tensor

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's look at the current implementation of
# `Adagrad (functional) `__
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# Some hyperparameters
eps = 1e-10
clr = 0.1

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

print("param:\n", param)
print("grad:\n", grad.to_dense())
print("state_sum:\n", state_sum)

######################################################################
#

state_sum = torch.full_like(param, 0.5)  # initial value for state sum
print(state_sum)

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()

# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified.
# This even drives home the point that the passed gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)
print("state_sum:\n", state_sum)
print("std:\n", std.to_dense())

######################################################################
# This is where we have a very important divergence.
# The addition of eps should technically be applied to all values, but instead is only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the values of the gradient are zero, they are still included if materialized, even though they
# could be compressed by other sparse storage layouts.
# This is technically quite brittle, even though someone could argue that eps is always very small.
#
# Moreover, an implementation of add_ for sparsity as a storage layout and compression scheme should cause densification,
# but we force it not to.
# For this one-off case it is fine until we want to introduce new compression schemes,
# such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats.
#

# We currently dodge all these concerns using the private method _values().
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor, undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
print("param:\n", param)

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call the semantic extension through sparsity "masked".
# Currently we can't have dense semantics with sparse storage or masked semantics with dense storage;
# MaskedTensor fixes that because it separates the storage from the semantics.
# Consider the above example using a masked gradient:
#

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
print("masked_grad:\n", masked_grad)

######################################################################
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# Let's print both this version and the regular version for easier comparison
print("state_sum:\n", state_sum)
print("std:\n", std)
print("state_sum2:\n", state_sum2)
print("std2:\n", std2)

######################################################################
#

# We can add support for in-place operations later.
Notice how this doesn't +# need to access any storage internals and is in general a lot shorter +std2 = std2.sqrt().add(eps) + +print("std:\n", std) +print("std2:\n", std2) + +# .data() indeed returns a sparse tensor +param2 = param2.add((masked_grad / std2).data(), alpha=-clr) +print("param2:\n", param2) + +###################################################################### +# Conclusion: Difference in code +# ------------------------------ +# +# For reference, this is the regular, dense code path without masked gradients or sparsity: +# :: +# +# state_sum.addcmul_(grad, grad, value=1) +# std = state_sum.sqrt().add_(eps) +# param.addcdiv_(grad, std, value=-clr) +# +# The vanilla tensor implementation for sparse is: +# + +grad = grad.coalesce() # the update is non-linear so indices must be unique +grad_indices = grad._indices() +grad_values = grad._values() +size = grad.size() + +state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) +std = state_sum.sparse_mask(grad) +std_values = std._values().sqrt_().add_(eps) +param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr) + +###################################################################### +# while MaskedTensor minimizes the code to the snippet: +# + +state_sum2 = state_sum2 + masked_grad.pow(2).data() +std2 = masked_tensor(state_sum2.to_sparse(), mask) +std2 = std2.sqrt().add(eps) +param2 = param2.add((masked_grad / std2).data(), alpha=-clr) + +###################################################################### +# And for good measure, let's make sure the parameters match: +# + +print("param:\n", param) +print("param2:\n", param2) + \ No newline at end of file diff --git a/index.rst b/index.rst index bb7a019c83b..3eed6fa5b07 100644 --- a/index.rst +++ b/index.rst @@ -814,6 +814,8 @@ Additional Resources beginner/maskedtensor_sparsity beginner/maskedtensor_distinguish_gradient beginner/maskedtensor_safe_softmax + beginner/maskedtensor_missing_nan_ops + beginner/maskedtensor_adagrad_sparse_semantics .. toctree::