imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)

- # Method #2: specify new names:
+ # Method #2: specify new names (note: this changes names out-of-place)
imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)

######################################################################
# Because named tensors coexist with unnamed tensors, we need a nice way to
- # write named- tensor-aware code that works with both named and unnamed tensors.
+ # write named tensor-aware code that works with both named and unnamed tensors.
# Use ``tensor.refine_names(*names)`` to refine dimensions and lift unnamed dims
# to named dims. Refining a dimension is defined as a "rename" with the following
# constraints:
# - A named dim can only be refined to have the same name.

imgs = torch.randn(3, 1, 1, 2)
- imgs = imgs.refine_names('N', 'C', 'H', 'W')
- print(imgs.names)
+ print(imgs.refine_names('N', 'C', 'H', 'W'))

# Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
print(imgs.refine_names(..., 'H', 'W').names)
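
# The constraint above means an already-named dim cannot be renamed through
# ``refine_names``; a minimal sketch of the failure (the variable name is our own):
named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
try:
    named_imgs.refine_names('N', 'channel', 'H', 'W')  # 'C' cannot be refined to 'channel'
except RuntimeError as err:
    print(err)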
@@ -116,34 +115,12 @@ def catch_error(fn):
# Most simple operations propagate names. The ultimate goal for named tensors is
# for all operations to propagate names in a reasonable, intuitive manner. Many
# common operations have been added at the time of the 1.3 release; here,
- # for example, is `.abs()`:
+ # for example, is ``.abs()``:

named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
print(named_imgs.abs().names)

######################################################################
- # Speaking of operations propogating names, let's quickly cover one
- # of the most important operations in PyTorch.
- #
- # Matrix multiply
- # ---------------
- #
- # ``torch.mm(A, B)`` performs a dot product between the second dim of `A`
- # and the first dim of `B`, returning a tensor with the first dim of `A`
- # and the second dim of `B`. (the other matmul functions, such as ``torch.matmul``,
- # ``torch.mv``, ``torch.dot``, behave similarly).
-
- markov_states = torch.randn(128, 5, names=('batch', 'D'))
- transition_matrix = torch.randn(5, 5, names=('in', 'out'))
-
- # Apply one transition
- new_state = markov_states @ transition_matrix
- print(new_state.names)
-
- ######################################################################
- # As you can see, matrix multiply does not check if the contracted dimensions
- # have the same name.
- #
# Accessors and Reduction
# -----------------------
#
@@ -179,31 +156,31 @@ def catch_error(fn):
######################################################################
# **Check names**: first, we will check whether the names of these two tensors
# match. Two names match if and only if they are equal (string equality) or at
- # least one is ``None`` (``None``s are essentially a special wildcard name).
- # The only one of these three that will error, therefore, is ``x+ z``:
+ # least one is ``None`` (``None`` is essentially a special wildcard name).
+ # The only one of these three that will error, therefore, is ``x + z``:

catch_error(lambda: x + z)

######################################################################
- # **Propagate names**: _unify_ the two names by returning the most refined name of
- # the two. With `x + y, `X` is more refined than `None`.
+ # **Propagate names**: `unify` the two names by returning the most refined name of
+ # the two. With ``x + y``, ``X`` is more refined than ``None``.

print((x + y).names)
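
# A minimal sketch of the wildcard rule with fresh tensors (the names here are
# our own illustrative choice):
a = torch.randn(3, names=('X',))
b = torch.randn(3)    # unnamed dim: matches any name
print((a + b).names)  # the more refined name, 'X', is propagated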

######################################################################
# Most name inference rules are straightforward but some of them can have
- # unexpected semantics. Let's go through a few more of them.
+ # unexpected semantics. Let's go through a couple you're likely to encounter:
+ # broadcasting and matrix multiply.
#
- # Broadcasting
- # ------------
+ # **Broadcasting**
#
# Named tensors do not change broadcasting behavior; they still broadcast by
# position. However, when checking whether two dimensions can be broadcasted,
# the names of those dimensions must match.
#
# Furthermore, broadcasting with named tensors can prevent incorrect behavior.
# The following code will error, whereas without `names` it would add
- # `per_batch_scale` to the last dimension of `imgs`.
+ # ``per_batch_scale`` to the last dimension of ``imgs``.

# Automatic broadcasting: expected to fail
imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
@@ -212,17 +189,37 @@ def catch_error(fn):

######################################################################
# How `should` we perform this broadcasting operation along the first
- # dimension? One way, involving names, would be to explicitly initialize
- # the ``per_batch_scale`` tensor as four dimensional, and give it names
- # (such as ``('N', None, None, None)`` below) so that name inference will
- # work.
+ # dimension? One way, involving names, would be to name the ``per_batch_scale``
+ # tensor such that it matches ``imgs.names``, as shown below.

imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
- per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', None, None, None))
+ per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', 'C', 'H', 'W'))
print((imgs * per_batch_scale_4d).names)

######################################################################
- # However, named tensors enable an even better way, which we'll cover next.
+ # Another way would be to use the new explicit broadcasting by names
+ # functionality, covered below.
+ #
+ # **Matrix multiply**
+ #
+ # ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
+ # and the first dim of ``B``, returning a tensor with the first dim of ``A``
+ # and the second dim of ``B``. (the other matmul functions, such as ``torch.matmul``,
+ # ``torch.mv``, ``torch.dot``, behave similarly).
+
+ markov_states = torch.randn(128, 5, names=('batch', 'D'))
+ transition_matrix = torch.randn(5, 5, names=('in', 'out'))
+
+ # Apply one transition
+ new_state = markov_states @ transition_matrix
+ print(new_state.names)
+
+ ######################################################################
+ # As you can see, matrix multiply does not check if the contracted dimensions
+ # have the same name.
+ #
+ # Next, we'll cover a new behavior that named tensors enable: explicit
+ # broadcasting by names.
#
# New behavior: Explicit broadcasting by names
# --------------------------------------------
@@ -250,6 +247,7 @@ def catch_error(fn):
per_batch_scale = per_batch_scale.refine_names('N')

named_result = imgs * per_batch_scale.align_as(imgs)
+ # note: named tensors do not yet work with allclose
assert torch.allclose(named_result.rename(None), correct_result)
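
# For reference, ``align_as`` lines ``per_batch_scale`` up with ``imgs`` by
# inserting size-one dims for the names it is missing (a quick illustrative check):
print(per_batch_scale.align_as(imgs).names)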

######################################################################
@@ -307,14 +305,12 @@ def catch_error(fn):
# ------------------------------------------
#
# See here (link to be included) for a detailed breakdown of what is
- # supported with the 1.3 release, what is on the roadmap to be supported soon,
- # and what will be supported in the future but not soon.
+ # supported with the 1.3 release.
#
# In particular, we want to call out three important features that are not
# currently supported:
#
- # - Retaining names when serializing or loading a serialized ``Tensor`` via
- # ``torch.save``
+ # - Saving or loading named tensors via ``torch.save`` or ``torch.load``
# - Multi-processing via ``torch.multiprocessing``
# - JIT support; for example, the following will error

@@ -330,15 +326,15 @@ def fn(x):
# As a workaround, please drop names via ``tensor = tensor.rename(None)``
# before using anything that does not yet support named tensors.
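#
# A minimal sketch of that workaround for ``torch.save`` (restoring the names
# afterwards with ``refine_names`` is our own illustrative choice, not part of
# the tutorial):

import io

buffer = io.BytesIO()
named = torch.randn(2, 3, names=('N', 'C'))
torch.save(named.rename(None), buffer)                 # drop names before saving
buffer.seek(0)
restored = torch.load(buffer).refine_names('N', 'C')   # re-attach names after loading
print(restored.names)

######################################################################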
#
- # Longer example: Multi-headed attention
+ # Longer example: Multi-head attention
# --------------------------------------
#
# Now we'll go through a complete example of implementing a common
- # PyTorch ``nn.Module``: multi-headed attention. We assume the reader is already
- # familiar with multi-headed attention; for a refresher, check out
+ # PyTorch ``nn.Module``: multi-head attention. We assume the reader is already
+ # familiar with multi-head attention; for a refresher, check out
# `this explanation <http://jalammar.github.io/illustrated-transformer/>`_.
#
- # We adapt the implementation of multi-headed attention from
+ # We adapt the implementation of multi-head attention from
# `ParlAI <https://github.com/facebookresearch/ParlAI>`_; specifically
# `here <https://github.com/facebookresearch/ParlAI/blob/f7db35cba3f3faf6097b3e6b208442cd564783d9/parlai/agents/transformer/modules.py#L907>`_.
# Read through the code at that example; then, compare with the code below,
@@ -431,7 +427,7 @@ def forward(self, query, key=None, value=None, mask=None):
# can be refined to ``['T', 'D']``, preventing potentially silent or confusing size
# mismatch errors later down the line.
#
- # **(II) Manipulating dimensions in ``prepare_head``**
+ # **(II) Manipulating dimensions in prepare_head**

# (II)
def prepare_head(tensor):
@@ -449,8 +445,6 @@ def prepare_head(tensor):
# multiple heads, finally rearranging the dim order to be ``[..., 'H', 'T', 'D_head']``.
# ParlAI implements ``prepare_head`` as the following, using ``view`` and ``transpose``
# operations:
- #
- # **(III) Explicit broadcasting by names**

def prepare_head(tensor):
    # input is [batch_size, seq_len, n_heads * dim_per_head]
@@ -486,7 +480,7 @@ def ignore():
# named tensors, we simply align ``attn_mask`` to ``dot_prod`` using ``align_as``
# and stop worrying about where to ``unsqueeze`` dims.
#
- # **(IV) More dimension manipulation using ``align_to`` and ``flatten``**
+ # **(IV) More dimension manipulation using align_to and flatten**

def ignore():
    # (IV)
@@ -497,8 +491,8 @@ def ignore():
    )

######################################################################
- # (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
- # meaningful than `view` (despite being more verbose).
+ # Here, as in (II), ``align_to`` and ``flatten`` are more semantically
+ # meaningful than ``view`` (despite being more verbose).
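#
# A minimal sketch of the contrast on a stand-alone tensor (the names and
# sizes here are our own illustrative choice, not the ParlAI module's):

t = torch.randn(2, 3, 4, 5, names=('N', 'H', 'T', 'D_head'))
# named version: reorder dims by name, then flatten two named dims into one
flat_named = t.align_to('N', 'T', 'H', 'D_head').flatten(['H', 'D_head'], 'D')
# positional equivalent: transpose + contiguous + view
flat_view = t.rename(None).transpose(1, 2).contiguous().view(2, 4, 3 * 5)
print(flat_named.names)
assert torch.allclose(flat_named.rename(None), flat_view)

######################################################################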
#
# Running the example
# -------------------
@@ -508,6 +502,18 @@ def ignore():
mask = torch.ones(n, t, names=('N', 'T'))
attn = MultiHeadAttention(h, d)
output = attn(query, mask=mask)
+ # works as expected!
+ print(output.names)
+
+ ######################################################################
+ # The above works as expected. Furthermore, note that in the code we
+ # did not mention the name of the batch dimension at all. In fact,
+ # the code is agnostic to the existence of the batch dimensions, so
+ # that we can run the following example-level code:
+
+ query = torch.randn(t, d, names=('T', 'D'))
+ mask = torch.ones(t, names=('T',))
+ output = attn(query, mask=mask)
print(output.names)

######################################################################