# -*- coding: utf-8 -*-
"""
- Introduction to Named Tensors in PyTorch
- ****************************************
+ (experimental) Introduction to Named Tensors in PyTorch
+ *******************************************************
**Author**: `Richard Zou <https://github.com/zou3519>`_

+ **Editor**: `Seth Weidman <https://github.com/SethHWeidman>`_
+
Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
with tensor dimensions. In most cases, operations that take dimension parameters will accept
dimension names, avoiding the need to track dimensions by position. In addition, named tensors
use names to automatically check that APIs are being used correctly at runtime, providing extra
safety. Names can also be used to rearrange dimensions, for example, to support
"broadcasting by name" rather than "broadcasting by position".

- this tutorial is intended as a guide to the functionality that will
+ This tutorial is intended as a guide to the functionality that will
be included with the 1.3 launch. By the end of it, you will be able to:

- Initiate a ``Tensor`` with named dimensions, as well as removing or renaming those dimensions
- Broadcasting operations
- Flattening and unflattening dimensions

- Finally, we'll put this into practice by writing a multi-headed attention module
+ Finally, we'll put this into practice by writing a multi-head attention module
using named tensors.

Named tensors in PyTorch are inspired by and done in collaboration with
`Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_.
The original idea and proof of concept were proposed in his
`January 2019 blog post <http://nlp.seas.harvard.edu/NamedTensor>`_.
- """

- ######################################################################
- # Basics: named dimensions
- # ------------------------
- #
- # PyTorch now allows Tensors to have named dimensions; factory functions
- # now take a new `names` argument that associates a name with each dimension.
- # This works with most factory functions, such as
- #
- # - `tensor`
- # - `empty`
- # - `ones`
- # - `zeros`
- # - `randn`
- # - `rand`
- #
- # Here we construct a tensor with names:
+ Basics: named dimensions
+ ========================
+
+ PyTorch now allows Tensors to have named dimensions; factory functions
+ now take a new `names` argument that associates a name with each dimension.
+ This works with most factory functions, such as
+
+ - `tensor`
+ - `empty`
+ - `ones`
+ - `zeros`
+ - `randn`
+ - `rand`
+
+ Here we construct a tensor with names:
+ """

import torch
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
+ print(imgs.names)
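As a quick editorial aside (a minimal sketch, not a line of this diff): the other factory functions listed above accept the same ``names`` argument, for example:

ones_imgs = torch.ones(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
zeros_imgs = torch.zeros(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
print(ones_imgs.names, zeros_imgs.names)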

+ ######################################################################
# Unlike in
# `the original named tensors blogpost <http://nlp.seas.harvard.edu/NamedTensor>`_,
- # named dimensions are ordered: `tensor.names[i]` is the name of the `i` th dimension of `tensor`.
+ # named dimensions are ordered: ``tensor.names[i]`` is the name of the ``i`` th dimension of ``tensor``.
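A one-line sketch of that ordering, using the ``imgs`` tensor created above:

print(imgs.names[0])  # prints 'N', the name of the first dimension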
#
# There are two ways to rename a ``Tensor``'s dimensions:
- #
-
- print(imgs.names)

# Method #1: set .names attribute
imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)

# Method #2: specify new names:
- imgs.rename(channel='C', width='W', height='H')
+ imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)

######################################################################
# The preferred way to remove names is to call ``tensor.rename(None)``:

- imgs.rename(None)
+ imgs = imgs.rename(None)
print(imgs.names)

######################################################################
# Unnamed tensors (tensors with no named dimensions) still work as normal and do
- # not have names in their `repr`.
+ # not have names in their ``repr``.

unnamed = torch.randn(2, 1, 3)
print(unnamed)
print(imgs.names)

######################################################################
- # Because named tensors coexist with unnamed tensors, we need a nice way to write named-tensor-aware
- # code that works with both named and unnamed tensors. Use ``tensor.refine_names(*names)`` to refine
- # dimensions and lift unnamed dims to named dims. Refining a dimension is defined as a "rename" with
- # the following constraints:
+ # Because named tensors coexist with unnamed tensors, we need a nice way to
+ # write named-tensor-aware code that works with both named and unnamed tensors.
+ # Use ``tensor.refine_names(*names)`` to refine dimensions and lift unnamed dims
+ # to named dims. Refining a dimension is defined as a "rename" with the following
+ # constraints:
#
# - A ``None`` dim can be refined to have any name
# - A named dim can only be refined to have the same name.

imgs = torch.randn(3, 1, 1, 2)
- named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
- print(named_imgs.names)
+ imgs = imgs.refine_names('N', 'C', 'H', 'W')
+ print(imgs.names)

# Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
print(imgs.refine_names(..., 'H', 'W').names)
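To make the "works with both named and unnamed tensors" point concrete, here is a minimal sketch (the helper name is made up for illustration):

def as_named_imgs(t):
    # Lifts unnamed dims to named dims, and checks dims that are already named.
    return t.refine_names('N', 'C', 'H', 'W')

print(as_named_imgs(torch.randn(3, 1, 1, 2)).names)                              # unnamed input
print(as_named_imgs(torch.randn(3, 1, 1, 2, names=('N', 'C', 'H', 'W'))).names)  # named input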
@@ -109,7 +110,7 @@ def catch_error(fn):
print(err)

# Tried to refine an existing name to a different name
- print(catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width')))
+ catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width'))

######################################################################
# Most simple operations propagate names. The ultimate goal for named tensors is
@@ -121,12 +122,34 @@ def catch_error(fn):
print(named_imgs.abs().names)

######################################################################
+ # Speaking of operations propagating names, let's quickly cover one
+ # of the most important operations in PyTorch.
+ #
+ # Matrix multiply
+ # ---------------
+ #
+ # ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
+ # and the first dim of ``B``, returning a tensor with the first dim of ``A``
+ # and the second dim of ``B``. (The other matmul functions, such as ``torch.matmul``,
+ # ``torch.mv``, and ``torch.dot``, behave similarly).
+
+ markov_states = torch.randn(128, 5, names=('batch', 'D'))
+ transition_matrix = torch.randn(5, 5, names=('in', 'out'))
+
+ # Apply one transition
+ new_state = markov_states @ transition_matrix
+ print(new_state.names)
+
+ ######################################################################
+ # As you can see, matrix multiply does not check if the contracted dimensions
+ # have the same name.
+ #
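Because the contraction is not name-checked, one defensive pattern (an editorial sketch, not part of this change) is to re-assert the expected output names afterwards; ``refine_names`` raises if they have drifted:

new_state = new_state.refine_names('batch', 'out')  # passes: the propagated names already match
print(new_state.names)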
# Accessors and Reduction
# -----------------------
#
# One can use dimension names to refer to dimensions instead of the positional
# dimension. These operations also propagate names. Indexing (basic and
- # advanced) has not been implemented yet but is on the roadmap. Using the `named_imgs`
+ # advanced) has not been implemented yet but is on the roadmap. Using the ``named_imgs``
# tensor from above, we can do:

output = named_imgs.sum(['C'])  # Perform a sum over the channel dimension
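Other dimension-taking ops follow the same pattern; a short sketch (assuming, as the text above states, that ops with dimension parameters accept names):

print(named_imgs.select('C', 0).names)   # drop the channel dimension by name
print(named_imgs.sum(['H', 'W']).names)  # reduce over two named dimensions at once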
@@ -139,12 +162,12 @@ def catch_error(fn):
# Name inference
# --------------
#
- # Names are propagated on operations in a two step process called **name inference**. It
- # works as follows:
+ # Names are propagated on operations in a two-step process called **name inference**.
+ # The two steps are:
#
- # 1. **Check names**: an operator may check that certain dimensions must match.
- # 2. **Propagate names**: name inference computes and propagates output names to
- #    output tensors.
+ # 1. **Check names**: an operator may perform automatic checks at runtime that
+ #    verify that certain dimension names match.
+ # 2. **Propagate names**: name inference propagates output names to output tensors.
#
# Let's go through the very small example of adding 2 one-dim tensors with no
# broadcasting.
@@ -153,65 +176,53 @@ def catch_error(fn):
y = torch.randn(3)
z = torch.randn(3, names=('Z',))

+ ######################################################################
# **Check names**: first, we will check whether the names of these two tensors
# match. Two names match if and only if they are equal (string equality) or at
# least one is ``None`` (``None``s are essentially a special wildcard name).
# The only one of these three that will error, therefore, is ``x + z``:

catch_error(lambda: x + z)

+ ######################################################################
# **Propagate names**: _unify_ the two names by returning the most refined name of
- # the two. With ``x + y``, ``X`` is more specific than ``None``.
+ # the two. With ``x + y``, ``X`` is more refined than ``None``.

print((x + y).names)

######################################################################
- # Most name inference rules are straightforward but some of them (the dot
- # product ones) can have unexpected semantics. Let's go through a few more of
- # them.
+ # Most name inference rules are straightforward but some of them can have
+ # unexpected semantics. Let's go through a few more of them.
#
# Broadcasting
# ------------
#
# Named tensors do not change broadcasting behavior; they still broadcast by
# position. However, when checking whether two dimensions can be
- # broadcasted, the names of those dimensions must match. Two names match if and
- # only if they are equal (string equality), or if one is None.
+ # broadcasted, the names of those dimensions must match.
#
- # We do not support **automatic broadcasting** by names because the output
- # ordering is ambiguous and does not work well with unnamed dimensions. However,
- # we support **explicit broadcasting** by names, which is introduced in a later
- # section. The two examples below help clarify this.
+ # Furthermore, broadcasting with named tensors can prevent incorrect behavior.
+ # The following code will error, whereas without ``names`` it would add
+ # ``per_batch_scale`` to the last dimension of ``imgs``.

# Automatic broadcasting: expected to fail
imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(6, names=('N',))
catch_error(lambda: imgs * per_batch_scale)

- # Explicit broadcasting: the names check out and the more refined names are propagated.
+ ######################################################################
+ # How *should* we perform this broadcasting operation along the first
+ # dimension? One way, involving names, would be to explicitly initialize
+ # the ``per_batch_scale`` tensor as four dimensional, and give it names
+ # (such as ``('N', None, None, None)`` below) so that name inference will
+ # work.
+
imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', None, None, None))
print((imgs * per_batch_scale_4d).names)

######################################################################
- # Matrix multiply
- # ---------------
- #
- # ``torch.mm(A, B)`` performs a dot product between the second dim of `A`
- # and the first dim of `B`, returning a tensor with the first dim of `A`
- # and the second dim of `B`. (the other matmul functions, such as `torch.matmul`,
- # `torch.mv`, `torch.dot`, behave similarly).
-
- markov_states = torch.randn(128, 5, names=('batch', 'D'))
- transition_matrix = torch.randn(5, 5, names=('in', 'out'))
-
- # Apply one transition
- new_state = markov_states @ transition_matrix
- print(new_state.names)
-
- ######################################################################
- # Inherently, matrix multiply does not check if the contracted dimensions
- # have the same name.
+ # However, named tensors enable an even better way, which we'll cover next.
#
# New behavior: Explicit broadcasting by names
# --------------------------------------------
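The code for this section falls outside the hunks shown. As a rough sketch of explicit broadcasting by name (using the ``imgs`` and ``per_batch_scale`` tensors defined earlier, and the ``align_as`` method that section (III) below also relies on):

print((imgs * per_batch_scale.align_as(imgs)).names)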
@@ -316,13 +327,13 @@ def fn(x):
catch_error(lambda: fn(imgs_named))

######################################################################
- # As a workaround, please drop names via `tensor = tensor.rename(None)`
+ # As a workaround, please drop names via ``tensor = tensor.rename(None)``
# before using anything that does not yet support named tensors.
#
# Longer example: Multi-headed attention
# --------------------------------------
#
- # Now we'll go through a complete example of implementing a common advanced
+ # Now we'll go through a complete example of implementing a common
# PyTorch ``nn.Module``: multi-headed attention. We assume the reader is already
# familiar with multi-headed attention; for a refresher, check out
# `this explanation <http://jalammar.github.io/illustrated-transformer/>`_.
@@ -378,12 +389,9 @@ def prepare_head(tensor):
return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
        .align_to(..., 'H', 'T', 'D_head'))

+ assert value is None
if self_attn:
    key = value = query
- elif value is None:
-     # key and value are the same, but query differs
-     key = key.refine_names(..., 'T', 'D')
-     value = key
key_len = key.size('T')
dim = key.size('D')
@@ -396,9 +404,7 @@ def prepare_head(tensor):
dot_prod.refine_names(..., 'H', 'T', 'T_key')  # just a check

# (III)
- # Named tensors doesn't support `==` yet; the following is a workaround.
- attn_mask = (mask.rename(None) == 0).refine_names(*mask.names)
- attn_mask = attn_mask.align_as(dot_prod)
+ attn_mask = (mask == 0).align_as(dot_prod)
dot_prod.masked_fill_(attn_mask, -float(1e20))

attn_weights = self.attn_dropout(F.softmax(dot_prod / scale, dim='T_key'))
@@ -422,7 +428,7 @@ def forward(self, query, key=None, value=None, mask=None):
######################################################################
# The ``query = query.refine_names(..., 'T', 'D')`` serves as enforceable documentation
# and lifts input dimensions to being named. It checks that the last two dimensions
- # can be refined to `['T', 'D']`, preventing potentially silent or confusing size
+ # can be refined to ``['T', 'D']``, preventing potentially silent or confusing size
# mismatch errors later down the line.
#
# **(II) Manipulating dimensions in ``prepare_head``**
@@ -435,13 +441,13 @@ def prepare_head(tensor):

######################################################################
# The first thing to note is how the code clearly states the input and
- # output dimensions: the input tensor must end with the `T` and `D` dims
- # and the output tensor ends in `H`, `T`, and `D_head` dims.
+ # output dimensions: the input tensor must end with the ``T`` and ``D`` dims
+ # and the output tensor ends in ``H``, ``T``, and ``D_head`` dims.
#
# The second thing to note is how clearly the code describes what is going on.
# prepare_head takes the key, query, and value and splits the embedding dim into
- # multiple heads, finally rearranging the dim order to be `[..., 'H', 'T', 'D_head']`.
- # ParlAI implements prepare_head as the following, using `view` and `transpose`
+ # multiple heads, finally rearranging the dim order to be ``[..., 'H', 'T', 'D_head']``.
+ # ParlAI implements ``prepare_head`` as the following, using ``view`` and ``transpose``
# operations:
#
# **(III) Explicit broadcasting by names**
@@ -460,30 +466,25 @@ def prepare_head(tensor):

######################################################################
# Our named tensor variant uses ops that, though more verbose, also have
- # more semantic meaning than `view` and `transpose` and include enforcable
+ # more semantic meaning than ``view`` and ``transpose`` and include enforceable
# documentation in the form of names.
#
# **(III) Explicit broadcasting by names**

def ignore():
    # (III)
-     # Named tensors doesn't support == yet; the following is a workaround.
-     attn_mask = (mask.renamed(None) == 0).refine_names(*mask.names)
-
-     # recall that we had dot_prod.refine_names(..., 'H', 'T', 'T_key')
-     attn_mask = attn_mask.align_as(dot_prod)
+     attn_mask = (mask == 0).align_as(dot_prod)

    dot_prod.masked_fill_(attn_mask, -float(1e20))

######################################################################
# ``mask`` usually has dims ``[N, T]`` (in the case of self attention) or
# ``[N, T, T_key]`` (in the case of encoder attention) while ``dot_prod``
# has dims ``[N, H, T, T_key]``. To make ``mask`` broadcast correctly with
- # ``dot_prod``, we would usually `unsqueeze` dims `1` and `-1` in the case of self
- # attention or `unsqueeze` dim `1` in the case of encoder attention. Using
- # named tensors, we can simply align the two tensors and stop worrying about
- # where to unsqueeze` dims. Using named tensors, we simply align `attn_mask`
- # to `dot_prod` using `align_as` and stop worrying about where to `unsqueeze` dims.
+ # ``dot_prod``, we would usually ``unsqueeze`` dims ``1`` and ``-1`` in the case of self
+ # attention or ``unsqueeze`` dim ``1`` in the case of encoder attention. Using
+ # named tensors, we simply align ``attn_mask`` to ``dot_prod`` using ``align_as``
+ # and stop worrying about where to ``unsqueeze`` dims.
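A standalone sketch of that alignment, with sizes and names chosen only for illustration:

mask_sketch = torch.zeros(2, 3, names=('N', 'T'))
dot_prod_sketch = torch.randn(2, 4, 3, 3, names=('N', 'H', 'T', 'T_key'))
print(mask_sketch.align_as(dot_prod_sketch).names)  # ('N', 'H', 'T', 'T_key'), with size-1 'H' and 'T_key' dims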
#
# **(IV) More dimension manipulation using ``align_to`` and ``flatten``**

@@ -496,7 +497,7 @@ def ignore():
)

######################################################################
- # (IV): Like (II), `align_to` and `flatten` are more semantically
+ # (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
# meaningful than `view` (despite being more verbose).
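A self-contained sketch of the pattern (sizes are arbitrary): ``align_to`` makes ``'H'`` and ``'D_head'`` adjacent so that ``flatten`` can merge them into a single ``'D'`` dim, and ``unflatten`` reverses it:

t = torch.randn(2, 4, 3, 5, names=('N', 'H', 'T', 'D_head'))
t = t.align_to('N', 'T', 'H', 'D_head')
flat = t.flatten(['H', 'D_head'], 'D')
print(flat.names)                                            # ('N', 'T', 'D')
print(flat.unflatten('D', [('H', 4), ('D_head', 5)]).names)  # back to ('N', 'T', 'H', 'D_head')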
#
# Running the example