Commit 5c15b66

Authored and committed by Seth Weidman
Formatting fixes, reordering for flow and clarity
1 parent 1761555 commit 5c15b66


intermediate_source/named_tensor_tutorial.py

Lines changed: 97 additions & 96 deletions
@@ -1,17 +1,19 @@
 # -*- coding: utf-8 -*-
 """
-Introduction to Named Tensors in PyTorch
-****************************************
+(experimental) Introduction to Named Tensors in PyTorch
+*******************************************************
 **Author**: `Richard Zou <https://github.com/zou3519>`_
 
+**Editor**: `Seth Weidman <https://github.com/SethHWeidman>`_
+
 Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
 with tensor dimensions. In most cases, operations that take dimension parameters will accept
 dimension names, avoiding the need to track dimensions by position. In addition, named tensors
 use names to automatically check that APIs are being used correctly at runtime, providing extra
 safety. Names can also be used to rearrange dimensions, for example, to support
 "broadcasting by name" rather than "broadcasting by position".
 
-this tutorial is intended as a guide to the functionality that will
+This tutorial is intended as a guide to the functionality that will
 be included with the 1.3 launch. By the end of it, you will be able to:
 
 - Initialize a ``Tensor`` with named dimensions, as well as remove or rename those dimensions
@@ -20,61 +22,59 @@
 - Broadcasting operations
 - Flattening and unflattening dimensions
 
-Finally, we'll put this into practice by writing a multi-headed attention module
+Finally, we'll put this into practice by writing a multi-head attention module
 using named tensors.
 
 Named tensors in PyTorch are inspired by and done in collaboration with
 `Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_.
 The original idea and proof of concept were proposed in his
 `January 2019 blog post <http://nlp.seas.harvard.edu/NamedTensor>`_.
-"""
 
-######################################################################
-# Basics: named dimensions
-# ------------------------
-#
-# PyTorch now allows Tensors to have named dimensions; factory functions
-# now take a new `names` argument that associates a name with each dimension.
-# This works with most factory functions, such as
-#
-# - `tensor`
-# - `empty`
-# - `ones`
-# - `zeros`
-# - `randn`
-# - `rand`
-#
-# Here we construct a tensor with names:
+Basics: named dimensions
+========================
+
+PyTorch now allows Tensors to have named dimensions; factory functions
+now take a new `names` argument that associates a name with each dimension.
+This works with most factory functions, such as
+
+- `tensor`
+- `empty`
+- `ones`
+- `zeros`
+- `randn`
+- `rand`
+
+Here we construct a tensor with names:
+"""
 
 import torch
 imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
+print(imgs.names)
 
+######################################################################
 # Unlike in
 # `the original named tensors blogpost <http://nlp.seas.harvard.edu/NamedTensor>`_,
-# named dimensions are ordered: `tensor.names[i]` is the name of the `i`th dimension of `tensor`.
+# named dimensions are ordered: ``tensor.names[i]`` is the name of the ``i`` th dimension of ``tensor``.
 #
 # There are two ways to rename a ``Tensor``'s dimensions:
-#
-
-print(imgs.names)
 
 # Method #1: set .names attribute
 imgs.names = ['batch', 'channel', 'width', 'height']
 print(imgs.names)
 
 # Method #2: specify new names:
-imgs.rename(channel='C', width='W', height='H')
+imgs = imgs.rename(channel='C', width='W', height='H')
 print(imgs.names)
 
 ######################################################################
 # The preferred way to remove names is to call ``tensor.rename(None)``:
 
-imgs.rename(None)
+imgs = imgs.rename(None)
 print(imgs.names)
 
 ######################################################################
 # Unnamed tensors (tensors with no named dimensions) still work as normal and do
-# not have names in their `repr`.
+# not have names in their ``repr``.
 
 unnamed = torch.randn(2, 1, 3)
 print(unnamed)
@@ -87,17 +87,18 @@
 print(imgs.names)
 
 ######################################################################
-# Because named tensors coexist with unnamed tensors, we need a nice way to write named-tensor-aware
-# code that works with both named and unnamed tensors. Use ``tensor.refine_names(*names)`` to refine
-# dimensions and lift unnamed dims to named dims. Refining a dimension is defined as a "rename" with
-# the following constraints:
+# Because named tensors coexist with unnamed tensors, we need a nice way to
+# write named-tensor-aware code that works with both named and unnamed tensors.
+# Use ``tensor.refine_names(*names)`` to refine dimensions and lift unnamed dims
+# to named dims. Refining a dimension is defined as a "rename" with the following
+# constraints:
 #
 # - A ``None`` dim can be refined to have any name
 # - A named dim can only be refined to have the same name.
 
 imgs = torch.randn(3, 1, 1, 2)
-named_imgs= imgs.refine_names('N', 'C', 'H', 'W')
-print(named_imgs.names)
+imgs = imgs.refine_names('N', 'C', 'H', 'W')
+print(imgs.names)
 
 # Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
 print(imgs.refine_names(..., 'H', 'W').names)
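The two refinement constraints listed in this hunk can be exercised in isolation. Below is a minimal sketch, not from the tutorial file itself, assuming PyTorch 1.3+ with the experimental named tensor API; the shape and the names 'N', 'C', and 'channels' are purely illustrative:

import torch

x = torch.randn(2, 3)                  # both dims are unnamed (None)
x = x.refine_names('N', 'C')           # OK: a None dim can take any name
print(x.names)                         # ('N', 'C')

print(x.refine_names('N', 'C').names)  # OK: refining a named dim to the same name

try:
    x.refine_names('N', 'channels')    # error: 'C' cannot be refined to 'channels'
except RuntimeError as err:
    print(err)
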
@@ -109,7 +110,7 @@ def catch_error(fn):
         print(err)
 
 # Tried to refine an existing name to a different name
-print(catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width')))
+catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width'))
 
 ######################################################################
 # Most simple operations propagate names. The ultimate goal for named tensors is
@@ -121,12 +122,34 @@ def catch_error(fn):
 print(named_imgs.abs().names)
 
 ######################################################################
+# Speaking of operations propagating names, let's quickly cover one
+# of the most important operations in PyTorch.
+#
+# Matrix multiply
+# ---------------
+#
+# ``torch.mm(A, B)`` performs a dot product between the second dim of `A`
+# and the first dim of `B`, returning a tensor with the first dim of `A`
+# and the second dim of `B`. (The other matmul functions, such as ``torch.matmul``,
+# ``torch.mv``, ``torch.dot``, behave similarly).
+
+markov_states = torch.randn(128, 5, names=('batch', 'D'))
+transition_matrix = torch.randn(5, 5, names=('in', 'out'))
+
+# Apply one transition
+new_state = markov_states @ transition_matrix
+print(new_state.names)
+
+######################################################################
+# As you can see, matrix multiply does not check if the contracted dimensions
+# have the same name.
+#
 # Accessors and Reduction
 # -----------------------
 #
 # One can use dimension names to refer to dimensions instead of the positional
 # dimension. These operations also propagate names. Indexing (basic and
-# advanced) has not been implemented yet but is on the roadmap. Using the `named_imgs`
+# advanced) has not been implemented yet but is on the roadmap. Using the ``named_imgs``
 # tensor from above, we can do:
 
 output = named_imgs.sum(['C']) # Perform a sum over the channel dimension
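To complement the matrix-multiply example above, here is a minimal sketch of the name-based accessors and reductions this hunk introduces (not from the tutorial file; same experimental 1.3 API assumed, and the shape and names below are illustrative):

import torch

named_imgs = torch.randn(2, 3, 4, 4, names=('N', 'C', 'H', 'W'))

print(named_imgs.size('C'))              # 3: look up a dimension's size by name
print(named_imgs.sum(['C']).names)       # ('N', 'H', 'W'): the reduced dim is dropped
print(named_imgs.sum(['H', 'W']).names)  # ('N', 'C'): reduce over several named dims
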
@@ -139,12 +162,12 @@ def catch_error(fn):
 # Name inference
 # --------------
 #
-# Names are propagated on operations in a two step process called **name inference**. It
-# works as follows:
+# Names are propagated on operations in a two-step process called **name inference**.
+# The two steps are:
 #
-# 1. **Check names**: an operator may check that certain dimensions must match.
-# 2. **Propagate names**: name inference computes and propagates output names to
-#    output tensors.
+# 1. **Check names**: an operator may perform automatic checks at runtime to
+#    verify that certain dimension names match.
+# 2. **Propagate names**: name inference propagates output names to output tensors.
 #
 # Let's go through the very small example of adding 2 one-dim tensors with no
 # broadcasting.
@@ -153,65 +176,53 @@ def catch_error(fn):
 y = torch.randn(3)
 z = torch.randn(3, names=('Z',))
 
+######################################################################
 # **Check names**: first, we will check whether the names of these two tensors
 # match. Two names match if and only if they are equal (string equality) or at
 # least one is ``None`` (``None``s are essentially a special wildcard name).
 # The only one of these three that will error, therefore, is ``x+z``:
 
 catch_error(lambda: x + z)
 
+######################################################################
 # **Propagate names**: *unify* the two names by returning the most refined name of
-# the two. With ``x + y``, ``X`` is more specific than ``None``.
+# the two. With ``x + y``, ``X`` is more refined than ``None``.
 
 print((x + y).names)
 
 ######################################################################
-# Most name inference rules are straightforward but some of them (the dot
-# product ones) can have unexpected semantics. Let's go through a few more of
-# them.
+# Most name inference rules are straightforward but some of them can have
+# unexpected semantics. Let's go through a few more of them.
 #
 # Broadcasting
 # ------------
 #
 # Named tensors do not change broadcasting behavior; they still broadcast by
 # position. However, when checking whether two dimensions can be
-# broadcasted, the names of those dimensions must match. Two names match if and
-# only if they are equal (string equality), or if one is None.
+# broadcasted, the names of those dimensions must match.
 #
-# We do not support **automatic broadcasting** by names because the output
-# ordering is ambiguous and does not work well with unnamed dimensions. However,
-# we support **explicit broadcasting** by names, which is introduced in a later
-# section. The two examples below help clarify this.
+# Furthermore, broadcasting with named tensors can prevent incorrect behavior.
+# The following code will error, whereas without `names` it would add
+# `per_batch_scale` to the last dimension of `imgs`.
 
 # Automatic broadcasting: expected to fail
 imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
 per_batch_scale = torch.rand(6, names=('N',))
 catch_error(lambda: imgs * per_batch_scale)
 
-# Explicit broadcasting: the names check out and the more refined names are propagated.
+######################################################################
+# How `should` we perform this broadcasting operation along the first
+# dimension? One way, involving names, would be to explicitly initialize
+# the ``per_batch_scale`` tensor as four-dimensional, and give it names
+# (such as ``('N', None, None, None)`` below) so that name inference will
+# work.
+
 imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
 per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', None, None, None))
 print((imgs * per_batch_scale_4d).names)
 
 ######################################################################
-# Matrix multiply
-# ---------------
-#
-# `torch.mm(A, B)`` performs a dot product between the second dim of `A`
-# and the first dim of `B`, returning a tensor with the first dim of `A`
-# and the second dim of `B`. (the other matmul functions, such as `torch.matmul`,
-# `torch.mv`, `torch.dot`, behave similarly).
-
-markov_states = torch.randn(128, 5, names=('batch', 'D'))
-transition_matrix = torch.randn(5, 5, names=('in', 'out'))
-
-# Apply one transition
-new_state = markov_states @ transition_matrix
-print(new_state.names)
-
-######################################################################
-# Inherently, matrix multiply does not check if the contracted dimensions
-# have the same name.
+# However, named tensors enable an even better way, which we'll cover next.
 #
 # New behavior: Explicit broadcasting by names
 # --------------------------------------------
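The "even better way" promised above is explicit broadcasting via ``align_as``. A minimal sketch under the same assumptions (illustrative sizes, not from the tutorial file):

import torch

imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(6, names=('N',))

# align_as inserts size-1 dims for the names it is missing and orders dims
# by name, so the multiply broadcasts along 'N' as intended.
scale = per_batch_scale.align_as(imgs)
print(scale.names, scale.shape)          # ('N', 'C', 'H', 'W') torch.Size([6, 1, 1, 1])
print((imgs * scale).names)              # ('N', 'C', 'H', 'W')
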
@@ -316,13 +327,13 @@ def fn(x):
 catch_error(lambda: fn(imgs_named))
 
 ######################################################################
-# As a workaround, please drop names via `tensor = tensor.rename(None)`
+# As a workaround, please drop names via ``tensor = tensor.rename(None)``
 # before using anything that does not yet support named tensors.
 #
 # Longer example: Multi-headed attention
 # --------------------------------------
 #
-# Now we'll go through a complete example of implementing a common advanced
+# Now we'll go through a complete example of implementing a common
 # PyTorch ``nn.Module``: multi-headed attention. We assume the reader is already
 # familiar with multi-headed attention; for a refresher, check out
 # `this explanation <http://jalammar.github.io/illustrated-transformer/>`_.
@@ -378,12 +389,9 @@ def prepare_head(tensor):
             return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                           .align_to(..., 'H', 'T', 'D_head'))
 
+        assert value is None
         if self_attn:
             key = value = query
-        elif value is None:
-            # key and value are the same, but query differs
-            key = key.refine_names(..., 'T', 'D')
-            value = key
         key_len = key.size('T')
         dim = key.size('D')
 
@@ -396,9 +404,7 @@ def prepare_head(tensor):
         dot_prod.refine_names(..., 'H', 'T', 'T_key') # just a check
 
         # (III)
-        # Named tensors doesn't support `==` yet; the following is a workaround.
-        attn_mask = (mask.rename(None) == 0).refine_names(*mask.names)
-        attn_mask = attn_mask.align_as(dot_prod)
+        attn_mask = (attn_mask == 0).align_as(dot_prod)
         dot_prod.masked_fill_(attn_mask, -float(1e20))
 
         attn_weights = self.attn_dropout(F.softmax(dot_prod / scale, dim='T_key'))
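The masking pattern in this hunk can also be run outside the module. A minimal sketch with made-up sizes, reusing the tutorial's 'N', 'H', 'T', 'T_key' names (experimental 1.3 API assumed; not part of the tutorial file):

import torch

N, H, T, T_key = 2, 4, 5, 5
dot_prod = torch.randn(N, H, T, T_key, names=('N', 'H', 'T', 'T_key'))

mask = torch.ones(N, T_key)
mask[:, -1] = 0                          # pretend the last key position is padding

# Lift the unnamed mask to named dims, then broadcast it against dot_prod
# by name instead of unsqueezing dims by hand.
attn_mask = (mask == 0).refine_names('N', 'T_key').align_as(dot_prod)
print(attn_mask.names)                   # ('N', 'H', 'T', 'T_key')
dot_prod.masked_fill_(attn_mask, -float(1e20))
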
@@ -422,7 +428,7 @@ def forward(self, query, key=None, value=None, mask=None):
 ######################################################################
 # The ``query = query.refine_names(..., 'T', 'D')`` serves as enforceable documentation
 # and lifts input dimensions to being named. It checks that the last two dimensions
-# can be refined to `['T', 'D']`, preventing potentially silent or confusing size
+# can be refined to ``['T', 'D']``, preventing potentially silent or confusing size
 # mismatch errors later down the line.
 #
 # **(II) Manipulating dimensions in ``prepare_head``**
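Point (II) can also be tried on its own: the ``unflatten``/``align_to`` pattern used in this commit's ``prepare_head``, shown as a minimal sketch with illustrative sizes (experimental 1.3 API assumed; not part of the tutorial file):

import torch

n_heads, dim_per_head = 4, 8
query = torch.randn(2, 5, n_heads * dim_per_head, names=('N', 'T', 'D'))

# Split 'D' into ('H', 'D_head'), then move 'H' in front of 'T' by name.
heads = (query.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
              .align_to(..., 'H', 'T', 'D_head'))
print(heads.names, heads.shape)          # ('N', 'H', 'T', 'D_head') torch.Size([2, 4, 5, 8])
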
@@ -435,13 +441,13 @@ def prepare_head(tensor):
 
 ######################################################################
 # The first thing to note is how the code clearly states the input and
-# output dimensions: the input tensor must end with the `T` and `D` dims
-# and the output tensor ends in `H`, `T`, and `D_head` dims.
+# output dimensions: the input tensor must end with the ``T`` and ``D`` dims
+# and the output tensor ends in ``H``, ``T``, and ``D_head`` dims.
 #
 # The second thing to note is how clearly the code describes what is going on.
 # prepare_head takes the key, query, and value and splits the embedding dim into
-# multiple heads, finally rearranging the dim order to be `[..., 'H', 'T', 'D_head']`.
-# ParlAI implements prepare_head as the following, using `view` and `transpose`
+# multiple heads, finally rearranging the dim order to be ``[..., 'H', 'T', 'D_head']``.
+# ParlAI implements ``prepare_head`` as the following, using ``view`` and ``transpose``
 # operations:
 #
 # **(III) Explicit broadcasting by names**
@@ -460,30 +466,25 @@ def prepare_head(tensor):
 
 ######################################################################
 # Our named tensor variant uses ops that, though more verbose, also have
-# more semantic meaning than `view` and `transpose` and include enforcable
+# more semantic meaning than ``view`` and ``transpose`` and include enforceable
 # documentation in the form of names.
 #
 # **(III) Explicit broadcasting by names**
 
 def ignore():
     # (III)
-    # Named tensors doesn't support == yet; the following is a workaround.
-    attn_mask = (mask.renamed(None) == 0).refine_names(*mask.names)
-
-    # recall that we had dot_prod.refine_names(..., 'H', 'T', 'T_key')
-    attn_mask = attn_mask.align_as(dot_prod)
+    attn_mask = (attn_mask == 0).align_as(dot_prod)
 
     dot_prod.masked_fill_(attn_mask, -float(1e20))
 
 ######################################################################
 # ``mask`` usually has dims ``[N, T]`` (in the case of self attention) or
 # ``[N, T, T_key]`` (in the case of encoder attention) while ``dot_prod``
 # has dims ``[N, H, T, T_key]``. To make ``mask`` broadcast correctly with
-# ``dot_prod``, we would usually `unsqueeze` dims `1` and `-1` in the case of self
-# attention or `unsqueeze` dim `1` in the case of encoder attention. Using
-# named tensors, we can simply align the two tensors and stop worrying about
-# where to unsqueeze` dims. Using named tensors, we simply align `attn_mask`
-# to `dot_prod` using `align_as` and stop worrying about where to `unsqueeze` dims.
+# ``dot_prod``, we would usually ``unsqueeze`` dims ``1`` and ``-1`` in the case of self
+# attention or ``unsqueeze`` dim ``1`` in the case of encoder attention. Using
+# named tensors, we simply align ``attn_mask`` to ``dot_prod`` using ``align_as``
+# and stop worrying about where to ``unsqueeze`` dims.
 #
 # **(IV) More dimension manipulation using ``align_to`` and ``flatten``**
 
@@ -496,7 +497,7 @@ def ignore():
     )
 
 ######################################################################
-# (IV): Like (II), `align_to` and `flatten` are more semantically
+# (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
 # meaningful than `view` (despite being more verbose).
 #
 # Running the example
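Point (IV) can likewise be tried outside the module: a minimal sketch of ``align_to`` followed by ``flatten`` (illustrative names and sizes, same experimental API assumed; not part of the tutorial file):

import torch

n_heads, dim_per_head = 4, 8
attn_applied = torch.randn(2, n_heads, 5, dim_per_head,
                           names=('N', 'H', 'T', 'D_head'))

# Reorder by name, then merge ('H', 'D_head') back into a single 'D' dim.
attentioned = (attn_applied.align_to('N', 'T', 'H', 'D_head')
                           .flatten(['H', 'D_head'], 'D'))
print(attentioned.names, attentioned.shape)  # ('N', 'T', 'D') torch.Size([2, 5, 32])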
