****************************************
**Author**: `Richard Zou <https://github.com/zou3519>`_

- `Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_ proposed the idea of
- `named tensors <http://nlp.seas.harvard.edu/NamedTensor>`_ in a January 2019 blog post as a
- way to enable more readable code when writing with the manipulations of multidimensional
- arrays necessary for coding up Transformer and Convolutional architectures. With PyTorch 1.3,
- we begin supporting the concept of named tensors by allowing a ``Tensor`` to have **named
- dimensions**; this tutorial is intended as a guide to the functionality that will
+ Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
+ with tensor dimensions. In most cases, operations that take dimension parameters will accept
+ dimension names, avoiding the need to track dimensions by position. In addition, named tensors
+ use names to automatically check that APIs are being used correctly at runtime, providing extra
+ safety. Names can also be used to rearrange dimensions, for example, to support
+ "broadcasting by name" rather than "broadcasting by position".
+
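For example, here is a small, illustrative teaser of the API that the rest of
this tutorial walks through in detail::

    import torch
    imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
    print(imgs.sum('C').names)  # ('N', 'H', 'W')
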
+ This tutorial is intended as a guide to the functionality that will
be included with the 1.3 launch. By the end of it, you will be able to:

- Initialize a ``Tensor`` with named dimensions, as well as remove or rename those dimensions
- Broadcasting operations
- Flattening and unflattening dimensions

- Finally, we'll put this into practice by coding the operations of multi-headed attention
- using named tensors, and see that the code is significantly more readable than it would
- be with regular, "unnamed" tensors!
+ Finally, we'll put this into practice by writing a multi-headed attention module
+ using named tensors.
+
+ Named tensors in PyTorch are inspired by and done in collaboration with
+ `Sasha Rush <https://tech.cornell.edu/people/alexander-rush/>`_.
+ The original idea and proof of concept were proposed in his
+ `January 2019 blog post <http://nlp.seas.harvard.edu/NamedTensor>`_.
"""

######################################################################
# Basics: named dimensions
# ------------------------
#
- # Tensors now take a new ``names`` argument that represents a name for each dimension.
- # Here we construct a tensor with names:
+ # PyTorch now allows Tensors to have named dimensions; factory functions
+ # now take a new ``names`` argument that associates a name with each dimension.
+ # This works with most factory functions, such as
+ #
+ # - ``tensor``
+ # - ``empty``
+ # - ``ones``
+ # - ``zeros``
+ # - ``randn``
+ # - ``rand``
#
+ # Here we construct a tensor with names:

import torch
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))

- ######################################################################
- # This works with most factory functions, such as:
- #
- # - ``tensor``
- # - ``empty``
- # - ``ones``
- # - ``zeros``
- #
# Unlike in
# `the original named tensors blogpost <http://nlp.seas.harvard.edu/NamedTensor>`_,
- # named dimensions are ordered. `tensor.names[i]` is the name of the `i`th dimension of `tensor`.
+ # named dimensions are ordered: ``tensor.names[i]`` is the name of the i-th dimension of ``tensor``.
#
- # There are two ways rename a ``Tensor``'s names:
+ # There are two ways to rename a ``Tensor``'s dimensions:
#

print(imgs.names)
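
# A minimal sketch of the two approaches (assuming the standard named-tensor
# API; the actual lines sit outside this hunk):

# rename in place by assigning to ``.names``
imgs.names = ['batch', 'channel', 'height', 'width']
print(imgs.names)

# rename out of place with ``tensor.rename`` (old name -> new name)
imgs = imgs.rename(channel='C', height='H', width='W')
print(imgs.names)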
# The preferred way to remove names is to call ``tensor.rename(None)``:

imgs = imgs.rename(None)
+ print(imgs.names)
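# With the renamed tensor assigned back to ``imgs``, this is expected to print
# (None, None, None, None).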

######################################################################
# Unnamed tensors (tensors with no named dimensions) still work as normal and do
- # not have names in their repr.
+ # not have names in their ``repr``.

unnamed = torch.randn(2, 1, 3)
print(unnamed)
# - A ``None`` dim can be refined to have any name
# - A named dim can only be refined to have the same name.

- print(imgs.names)
- print(imgs.refine_names('N', 'C', 'H', 'W').names)
+ imgs = torch.randn(3, 1, 1, 2)
+ named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
+ print(named_imgs.names)

# Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
print(imgs.refine_names(..., 'H', 'W').names)
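
# For illustration: refining a dim that already has a different name raises an
# error (``catch_error`` is the helper defined earlier in this file).
catch_error(lambda: named_imgs.refine_names('batch', 'channel', 'height', 'width'))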
@@ -105,7 +114,7 @@ def catch_error(fn):
######################################################################
# Most simple operations propagate names. The ultimate goal for named tensors is
# for all operations to propagate names in a reasonable, intuitive manner. Many
- # common operations have been implemented at the time of the 1.3 release; here,
+ # common operations have been added as of the 1.3 release; here,
# for example, is ``.abs()``:

named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
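# ``abs()`` is a simple elementwise op, so it should carry the input names
# through to the output:
print(named_imgs.abs().names)  # expected: ('N', 'C', 'H', 'W')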
@@ -133,9 +142,9 @@ def catch_error(fn):
# Names are propagated on operations in a two-step process called **name inference**. It
# works as follows:
#
- # - **Check names**: an operator may check that certain dimensions must match.
- # - **Propagate names**: name inference computes and propagates output names to
- #   output tensors.
+ # 1. **Check names**: an operator may check that certain dimensions must match.
+ # 2. **Propagate names**: name inference computes and propagates output names to
+ #    output tensors.
#
# Let's go through the very small example of adding 2 one-dim tensors with no
# broadcasting.
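#
# The tensors themselves are defined outside this hunk; a plausible setup,
# consistent with the names used below, is:

x = torch.randn(3, names=('X',))
y = torch.randn(3)
z = torch.randn(3, names=('Z',))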
@@ -151,7 +160,7 @@ def catch_error(fn):

catch_error(lambda: x + z)

- # **Propagate names**: unify the two names by returning the most refined name of
+ # **Propagate names**: *unify* the two names by returning the most refined name of
# the two. With ``x + y``, ``X`` is more specific than ``None``.

print((x + y).names)
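# Expected output: ('X',) -- 'X' is more refined than None.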
@@ -188,11 +197,10 @@ def catch_error(fn):
# Matrix multiply
# ---------------
#
- # Of course, many of you may be wondering about the very special operation of
- # matrix multiplication. ``torch.mm(A, B)`` contracts away the second dimension
- # of ``A`` with the first dimension of ``B``, returning a tensor with the first
- # dim of ``A`` and the second dim of ``B``. (the other matmul functions,
- # ``torch.matmul``, ``torch.mv``, ``torch.dot``, behave similarly):
+ # ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
+ # and the first dim of ``B``, returning a tensor with the first dim of ``A``
+ # and the second dim of ``B``. (The other matmul functions, such as
+ # ``torch.matmul``, ``torch.mv``, and ``torch.dot``, behave similarly.)

markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))
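
# The multiply itself sits outside this hunk; presumably something like the
# following sketch, which contracts the 'D' dim of ``markov_states`` with the
# 'in' dim of ``transition_matrix``:
new_state = markov_states @ transition_matrix  # expected names: ('batch', 'out')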
@@ -202,6 +210,9 @@ def catch_error(fn):
print(new_state.names)

######################################################################
+ # Inherently, matrix multiply does not check if the contracted dimensions
+ # have the same name.
+ #
# New behavior: Explicit broadcasting by names
# --------------------------------------------
#
@@ -272,6 +283,14 @@ def catch_error(fn):

print(weight.grad)  # Unnamed for now. Will be named in the future

+ weight.grad.zero_()
+ grad_loss = grad_loss.refine_names('C')
+ loss = (x - weight).abs()
+ # Ideally we'd check that the names of loss and grad_loss match but we don't yet.
+ loss.backward(grad_loss)
+
+ print(weight.grad)  # still unnamed
+
######################################################################
# Other supported (and unsupported) features
# ------------------------------------------
@@ -280,8 +299,8 @@ def catch_error(fn):
# supported with the 1.3 release, what is on the roadmap to be supported soon,
# and what will be supported in the future but not soon.
#
- # In particular, three important features that we do not have plans to support
- # soon are:
+ # In particular, we want to call out three important features that are not
+ # currently supported:
#
# - Retaining names when serializing or loading a serialized ``Tensor`` via
#   ``torch.save``
@@ -313,7 +332,8 @@ def fn(x):
# `here <https://github.com/facebookresearch/ParlAI/blob/f7db35cba3f3faf6097b3e6b208442cd564783d9/parlai/agents/transformer/modules.py#L907>`_.
# Read through the code at that example; then, compare with the code below,
# noting that there are four places labeled (I), (II), (III), and (IV), where
- # using named tensors enables more readable code.
+ # using named tensors enables more readable code; we will dive into each of these
+ # after the code block.

import torch.nn as nn
import torch.nn.functional as F
@@ -393,8 +413,6 @@ def prepare_head(tensor):
        return self.out_lin(attentioned).refine_names(..., 'T', 'D')

######################################################################
- # Let's dive into each of these areas in turn:
- #
# **(I) Refining the input tensor dims**

def forward(self, query, key=None, value=None, mask=None):
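    # (I) A sketch of the refinement step; the full body sits outside this hunk,
    # and the dim names follow the conventions used elsewhere in this code.
    query = query.refine_names(..., 'T', 'D')
    self_attn = key is None and value is None
    if self_attn:
        mask = mask.refine_names(..., 'T')
    else:
        mask = mask.refine_names(..., 'T', 'T_key')  # encoder attention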
@@ -409,18 +427,6 @@ def forward(self, query, key=None, value=None, mask=None):
#
# **(II) Manipulating dimensions in ``prepare_head``**

- def prepare_head(tensor):
-     # (II)
-     tensor = tensor.refine_names('N', 'T', 'D')
-     return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
-             .align_to('N', 'H', 'T', 'D_head').contiguous())
-
- ######################################################################
- # Next, multihead attention takes the key, query, and value and splits their
- # feature dimensions into multiple heads and rearranges the dim order to be
- # ``['N', 'H', 'T', 'D_head']``. We can achieve something similar using view
- # and transpose operations like the following:
-
# (II)
def prepare_head(tensor):
    tensor = tensor.refine_names(..., 'T', 'D')
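    # The remainder of ``prepare_head`` sits outside this hunk; a sketch along
    # the lines of the removed version above, but refined with ``...``:
    return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                  .align_to(..., 'H', 'T', 'D_head'))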
@@ -470,13 +476,13 @@ def ignore():
dot_prod.masked_fill_(attn_mask, -float(1e20))

######################################################################
- # ``mask`` usually has dims ``[N, T]`` (in the case of self- attention) or
+ # ``mask`` usually has dims ``[N, T]`` (in the case of self attention) or
# ``[N, T, T_key]`` (in the case of encoder attention) while ``dot_prod``
# has dims ``[N, H, T, T_key]``. To make ``mask`` broadcast correctly with
# ``dot_prod``, we would usually ``unsqueeze`` dims ``1`` and ``-1`` in the case of self
# attention or ``unsqueeze`` dim ``1`` in the case of encoder attention. Using
# named tensors, we can stop worrying about
- # where to ``unsqueeze`` dims. Using named tensors, we simply align ``attn_mask``
+ # where to ``unsqueeze`` dims: we simply align ``attn_mask``
# to ``dot_prod`` using ``align_as``.
#
# **(IV) More dimension manipulation using ``align_to`` and ``flatten``**
@@ -490,8 +496,8 @@ def ignore():
)

######################################################################
- # (IV): Like (II), using ``align_to`` and ``flatten`` are more semantically
- # meaningful than ``view``.
+ # (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
+ # meaningful than ``view`` (despite being more verbose).
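
# A small, self-contained illustration of the pattern (the shapes here are
# made up for the example):
t = torch.randn(2, 4, 3, 5, names=('N', 'H', 'T', 'D_head'))
flat = (t.align_to('N', 'T', 'H', 'D_head')
         .contiguous()
         .flatten(['H', 'D_head'], 'D'))
print(flat.names)  # ('N', 'T', 'D')
print(flat.shape)  # torch.Size([2, 3, 20])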
#
# Running the example
# -------------------