
Commit 1761555

Author: Seth Weidman
Final draft of named tensor tutorial
1 parent bdf3c13, commit 1761555


intermediate_source/named_tensor_tutorial.py

Lines changed: 61 additions & 55 deletions
@@ -4,12 +4,14 @@
 ****************************************
 **Author**: `Richard Zou <https://github.com/zou3519>`_
 
-`Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_ proposed the idea of
-`named tensors <http://nlp.seas.harvard.edu/NamedTensor>`_ in a January 2019 blog post as a
-way to enable more readable code when writing with the manipulations of multidimensional
-arrays necessary for coding up Transformer and Convolutional architectures. With PyTorch 1.3,
-we begin supporting the concept of named tensors by allowing a ``Tensor`` to have **named
-dimensions**; this tutorial is intended as a guide to the functionality that will
+Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
+with tensor dimensions. In most cases, operations that take dimension parameters will accept
+dimension names, avoiding the need to track dimensions by position. In addition, named tensors
+use names to automatically check that APIs are being used correctly at runtime, providing extra
+safety. Names can also be used to rearrange dimensions, for example, to support
+"broadcasting by name" rather than "broadcasting by position".
+
+This tutorial is intended as a guide to the functionality that will
 be included with the 1.3 launch. By the end of it, you will be able to:
 
 - Create a ``Tensor`` with named dimensions, as well as remove or rename those dimensions
@@ -18,35 +20,40 @@
 - Broadcasting operations
 - Flattening and unflattening dimensions
 
-Finally, we'll put this into practice by coding the operations of multi-headed attention
-using named tensors, and see that the code is significantly more readable than it would
-be with regular, "unnamed" tensors!
+Finally, we'll put this into practice by writing a multi-headed attention module
+using named tensors.
+
+Named tensors in PyTorch are inspired by and done in collaboration with
+`Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_.
+The original idea and proof of concept were proposed in his
+`January 2019 blog post <http://nlp.seas.harvard.edu/NamedTensor>`_.
 """
 
 ######################################################################
 # Basics: named dimensions
 # ------------------------
 #
-# Tensors now take a new ``names`` argument that represents a name for each dimension.
-# Here we construct a tensor with names:
+# PyTorch now allows Tensors to have named dimensions; factory functions
+# now take a new ``names`` argument that associates a name with each dimension.
+# This works with most factory functions, such as:
+#
+# - ``tensor``
+# - ``empty``
+# - ``ones``
+# - ``zeros``
+# - ``randn``
+# - ``rand``
 #
+# Here we construct a tensor with names:
 
 import torch
 imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
 
-######################################################################
-# This works with most factory functions, such as:
-#
-# - ``tensor``
-# - ``empty``
-# - ``ones``
-# - ``zeros``
-#
 # Unlike in
 # `the original named tensors blogpost <http://nlp.seas.harvard.edu/NamedTensor>`_,
-# named dimensions are ordered. `tensor.names[i]` is the name of the `i`th dimension of `tensor`.
+# named dimensions are ordered: ``tensor.names[i]`` is the name of the ``i``-th dimension of ``tensor``.
 #
-# There are two ways rename a ``Tensor``'s names:
+# There are two ways to rename a ``Tensor``'s dimensions:
 #
 
 print(imgs.names)
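
The two renaming methods referenced above sit just outside this hunk. As a minimal sketch of what they look like with the PyTorch 1.3 API (the variable and dimension names below are illustrative, not necessarily the file's exact code):

    # Method 1: assign to the .names attribute (renames in place)
    imgs.names = ['batch', 'channel', 'width', 'height']
    print(imgs.names)

    # Method 2: call rename(), which returns a renamed tensor (out of place)
    imgs = imgs.rename(channel='C', width='W', height='H')
    print(imgs.names)   # ('batch', 'C', 'W', 'H')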
@@ -63,10 +70,11 @@
 # The preferred way to remove names is to call ``tensor.rename(None)``:
 
 imgs = imgs.rename(None)
+print(imgs.names)
 
 ######################################################################
 # Unnamed tensors (tensors with no named dimensions) still work as normal and do
-# not have names in their repr.
+# not have names in their ``repr``.
 
 unnamed = torch.randn(2, 1, 3)
 print(unnamed)
@@ -87,8 +95,9 @@
 # - A ``None`` dim can be refined to have any name
 # - A named dim can only be refined to have the same name.
 
-print(imgs.names)
-print(imgs.refine_names('N', 'C', 'H', 'W').names)
+imgs = torch.randn(3, 1, 1, 2)
+named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
+print(named_imgs.names)
 
 # Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
 print(imgs.refine_names(..., 'H', 'W').names)
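
A small self-contained sketch of the two refinement rules stated above (the shape and names here are chosen only for illustration):

    import torch

    t = torch.randn(3, 1, 1, 2, names=('N', None, None, None))
    print(t.refine_names('N', 'C', 'H', 'W').names)   # None dims accept any name
    try:
        t.refine_names('batch', 'C', 'H', 'W')        # an already-named dim cannot change name
    except RuntimeError as err:
        print(err)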
@@ -105,7 +114,7 @@ def catch_error(fn):
 ######################################################################
 # Most simple operations propagate names. The ultimate goal for named tensors is
 # for all operations to propagate names in a reasonable, intuitive manner. Many
-# common operations have been implemented at the time of the 1.3 release; here,
+# common operations have been added as of the 1.3 release; here,
 # for example, is ``.abs()``:
 
 named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
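
The ``.abs()`` call itself falls outside this hunk; since elementwise ops propagate names unchanged, the expected output is:

    print(named_imgs.abs().names)   # ('N', 'C', 'H', 'W')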
@@ -133,9 +142,9 @@ def catch_error(fn):
 # Names are propagated on operations in a two-step process called **name inference**. It
 # works as follows:
 #
-# - **Check names**: an operator may check that certain dimensions must match.
-# - **Propagate names**: name inference computes and propagates output names to
-#   output tensors.
+# 1. **Check names**: an operator may check that certain dimensions must match.
+# 2. **Propagate names**: name inference computes and propagates output names to
+#    output tensors.
 #
 # Let's go through the very small example of adding two one-dim tensors with no
 # broadcasting.
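
The definitions of ``x``, ``y``, and ``z`` used in the next hunk are not shown in this diff. A plausible setup, consistent with how they behave below (treat it as an assumption, not the file's exact code):

    import torch

    x = torch.randn(3, names=('X',))
    y = torch.randn(3)                  # unnamed
    z = torch.randn(3, names=('Z',))
    # x + z fails the name check ('X' != 'Z'); x + y succeeds and its result is named ('X',)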
@@ -151,7 +160,7 @@ def catch_error(fn):
 
 catch_error(lambda: x + z)
 
-# **Propagate names**: unify the two names by returning the most refined name of
+# **Propagate names**: *unify* the two names by returning the most refined name of
 # the two. With ``x + y``, ``X`` is more specific than ``None``.
 
 print((x + y).names)
@@ -188,11 +197,10 @@ def catch_error(fn):
 # Matrix multiply
 # ---------------
 #
-# Of course, many of you may be wondering about the very special operation of
-# matrix multiplication. ``torch.mm(A, B)`` contracts away the second dimension
-# of ``A`` with the first dimension of ``B``, returning a tensor with the first
-# dim of ``A`` and the second dim of ``B``. (the other matmul functions,
-# ``torch.matmul``, ``torch.mv``, ``torch.dot``, behave similarly):
+# ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
+# and the first dim of ``B``, returning a tensor with the first dim of ``A``
+# and the second dim of ``B``. (The other matmul functions, such as ``torch.matmul``,
+# ``torch.mv``, and ``torch.dot``, behave similarly.)
 
 markov_states = torch.randn(128, 5, names=('batch', 'D'))
 transition_matrix = torch.randn(5, 5, names=('in', 'out'))
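
The contraction step that produces ``new_state`` is outside this hunk; given the rule above (first dim of ``A``, second dim of ``B``), it presumably looks like:

    new_state = torch.mm(markov_states, transition_matrix)
    print(new_state.names)   # ('batch', 'out'); 'D' and 'in' are contracted away without a name check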
@@ -202,6 +210,9 @@ def catch_error(fn):
 print(new_state.names)
 
 ######################################################################
+# Inherently, matrix multiply does not check if the contracted dimensions
+# have the same name.
+#
 # New behavior: Explicit broadcasting by names
 # --------------------------------------------
 #
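
The body of this new section is outside the hunk. A hedged sketch of the behavior it describes (shapes and names are made up): named dims are matched by name from the right, so a broadcast that would pair up differently named dims is rejected, and ``align_as`` makes the intended broadcast explicit.

    import torch

    imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
    per_batch_scale = torch.rand(2, names=('N',))

    try:
        imgs * per_batch_scale                 # error: 'N' would line up with 'W' positionally
    except RuntimeError as err:
        print(err)

    scale = per_batch_scale.align_as(imgs)     # shape (2, 1, 1, 1), names ('N', 'C', 'H', 'W')
    print((imgs * scale).names)                # ('N', 'C', 'H', 'W')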
@@ -272,6 +283,14 @@ def catch_error(fn):
 
 print(weight.grad) # Unnamed for now. Will be named in the future
 
+weight.grad.zero_()
+grad_loss = grad_loss.refine_names('C')
+loss = (x - weight).abs()
+# Ideally we'd check that the names of loss and grad_loss match but we don't yet.
+loss.backward(grad_loss)
+
+print(weight.grad) # still unnamed
+
 ######################################################################
 # Other supported (and unsupported) features
 # ------------------------------------------
@@ -280,8 +299,8 @@ def catch_error(fn):
 # supported with the 1.3 release, what is on the roadmap to be supported soon,
 # and what will be supported in the future but not soon.
 #
-# In particular, three important features that we do not have plans to support
-# soon are:
+# In particular, we want to call out three important features that are not
+# currently supported:
 #
 # - Retaining names when serializing or loading a serialized ``Tensor`` via
 #   ``torch.save``
@@ -313,7 +332,8 @@ def fn(x):
 # `here <https://github.com/facebookresearch/ParlAI/blob/f7db35cba3f3faf6097b3e6b208442cd564783d9/parlai/agents/transformer/modules.py#L907>`_.
 # Read through the code at that example; then, compare with the code below,
 # noting that there are four places labeled (I), (II), (III), and (IV), where
-# using named tensors enables more readable code.
+# using named tensors enables more readable code; we will dive into each of these
+# after the code block.
 
 import torch.nn as nn
 import torch.nn.functional as F
@@ -393,8 +413,6 @@ def prepare_head(tensor):
         return self.out_lin(attentioned).refine_names(..., 'T', 'D')
 
 ######################################################################
-# Let's dive into each of these areas in turn:
-#
 # **(I) Refining the input tensor dims**
 
 def forward(self, query, key=None, value=None, mask=None):
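
The body of this ``forward`` snippet is elided here. The essence of step (I) is refining the trailing dims of the incoming tensors so the rest of the module can refer to them by name; a standalone illustration (the shape is made up, the names ``'T'`` and ``'D'`` come from this diff):

    import torch

    query = torch.randn(8, 5, 32)                # e.g. (batch, time, feature), unnamed
    query = query.refine_names(..., 'T', 'D')    # (I): name the trailing dims
    print(query.names)                           # (None, 'T', 'D')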
@@ -409,18 +427,6 @@ def forward(self, query, key=None, value=None, mask=None):
 #
 # **(II) Manipulating dimensions in ``prepare_head``**
 
-def prepare_head(tensor):
-    # (II)
-    tensor = tensor.refine_names('N', 'T', 'D')
-    return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
-            .align_to('N', 'H', 'T', 'D_head').contiguous())
-
-######################################################################
-# Next, multihead attention takes the key, query, and value and splits their
-# feature dimensions into multiple heads and rearranges the dim order to be
-# ``['N', 'H', 'T', 'D_head']``. We can achieve something similar using view
-# and transpose operations like the following:
-
 # (II)
 def prepare_head(tensor):
     tensor = tensor.refine_names(..., 'T', 'D')
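
The rest of ``prepare_head`` is outside this hunk. Based on the version removed above, it presumably continues by unflattening ``'D'`` into heads and reordering by name; a standalone sketch (``n_heads`` and ``dim_per_head`` belong to the enclosing module, so the values here are made up):

    import torch

    n_heads, dim_per_head = 4, 8

    def prepare_head_sketch(tensor):
        tensor = tensor.refine_names(..., 'T', 'D')
        return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                .align_to(..., 'H', 'T', 'D_head'))

    q = torch.randn(2, 5, n_heads * dim_per_head, names=('N', 'T', 'D'))
    print(prepare_head_sketch(q).names)   # ('N', 'H', 'T', 'D_head')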
@@ -470,13 +476,13 @@ def ignore():
     dot_prod.masked_fill_(attn_mask, -float(1e20))
 
 ######################################################################
-# ``mask`` usually has dims ``[N, T]`` (in the case of self-attention) or
+# ``mask`` usually has dims ``[N, T]`` (in the case of self attention) or
 # ``[N, T, T_key]`` (in the case of encoder attention) while ``dot_prod``
 # has dims ``[N, H, T, T_key]``. To make ``mask`` broadcast correctly with
 # ``dot_prod``, we would usually ``unsqueeze`` dims ``1`` and ``-1`` in the case of self
 # attention or ``unsqueeze`` dim ``1`` in the case of encoder attention. Using
 # named tensors, we can simply align the two tensors and stop worrying about
-# where to ``unsqueeze`` dims. Using named tensors, we simply align `attn_mask`
+# where to ``unsqueeze`` dims: we align ``attn_mask``
 # to ``dot_prod`` using ``align_as``.
 #
 # **(IV) More dimension manipulation using ``align_to`` and ``flatten``**
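
The (IV) code itself sits outside this hunk (its closing parenthesis appears in the next one). A generic, standalone illustration of the ``align_to`` + ``flatten`` pattern it relies on (names and shapes here are made up):

    import torch

    t = torch.randn(2, 4, 5, 8, names=('N', 'H', 'T', 'D_head'))
    flat = t.align_to('N', 'T', 'H', 'D_head').flatten(['H', 'D_head'], 'D')
    print(flat.names)    # ('N', 'T', 'D')
    print(flat.shape)    # torch.Size([2, 5, 32])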
@@ -490,8 +496,8 @@ def ignore():
     )
 
 ######################################################################
-# (IV): Like (II), using ``align_to`` and ``flatten`` are more semantically
-# meaningful than `view`.
+# (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
+# meaningful than ``view`` (despite being more verbose).
 #
 # Running the example
 # -------------------