
Commit 91ccc85

Final formatting fixes, addressing comments on flow and wording
Authored and committed by Seth Weidman
1 parent 5c15b66 commit 91ccc85

1 file changed (+61, -55 lines)


intermediate_source/named_tensor_tutorial.py

Lines changed: 61 additions & 55 deletions
@@ -62,7 +62,7 @@
 imgs.names = ['batch', 'channel', 'width', 'height']
 print(imgs.names)

-# Method #2: specify new names:
+# Method #2: specify new names (note: this changes names out-of-place)
 imgs = imgs.rename(channel='C', width='W', height='H')
 print(imgs.names)

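For reference (not part of this commit), a minimal sketch showing that ``rename`` is out-of-place; the tensor sizes are illustrative:

import torch

imgs = torch.randn(1, 2, 2, 3, names=('batch', 'channel', 'width', 'height'))
renamed = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)     # ('batch', 'channel', 'width', 'height') -- the original is unchanged
print(renamed.names)  # ('batch', 'C', 'W', 'H')

In contrast, Method #1 above (assigning to ``imgs.names``) changes the names in place.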
@@ -88,7 +88,7 @@
 
 ######################################################################
 # Because named tensors coexist with unnamed tensors, we need a nice way to
-# write named-tensor-aware code that works with both named and unnamed tensors.
+# write named tensor-aware code that works with both named and unnamed tensors.
 # Use ``tensor.refine_names(*names)`` to refine dimensions and lift unnamed dims
 # to named dims. Refining a dimension is defined as a "rename" with the following
 # constraints:
@@ -97,8 +97,7 @@
 # - A named dim can only be refined to have the same name.

 imgs = torch.randn(3, 1, 1, 2)
-imgs = imgs.refine_names('N', 'C', 'H', 'W')
-print(imgs.names)
+print(imgs.refine_names('N', 'C', 'H', 'W'))

 # Coerces the last two dims to 'H' and 'W'. In Python 2, use the string '...' instead of ...
 print(imgs.refine_names(..., 'H', 'W').names)
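A small sketch of the two refinement constraints listed above (illustrative; assumes ``import torch``):

t = torch.randn(3, 1, 1, 2)
t = t.refine_names('N', 'C', 'H', 'W')           # unnamed dims can be lifted to any name
print(t.refine_names('N', 'C', 'H', 'W').names)  # refining to the same names is allowed
# t.refine_names('N', 'C', 'H', 'X') would raise an error:
# a dim already named 'W' can only be refined to 'W'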
@@ -116,34 +115,12 @@ def catch_error(fn):
 # Most simple operations propagate names. The ultimate goal for named tensors is
 # for all operations to propagate names in a reasonable, intuitive manner. Many
 # common operations have been added at the time of the 1.3 release; here,
-# for example, is `.abs()`:
+# for example, is ``.abs()``:

 named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
 print(named_imgs.abs().names)

 ######################################################################
-# Speaking of operations propogating names, let's quickly cover one
-# of the most important operations in PyTorch.
-#
-# Matrix multiply
-# ---------------
-#
-# ``torch.mm(A, B)`` performs a dot product between the second dim of `A`
-# and the first dim of `B`, returning a tensor with the first dim of `A`
-# and the second dim of `B`. (the other matmul functions, such as ``torch.matmul``,
-# ``torch.mv``, ``torch.dot``, behave similarly).
-
-markov_states = torch.randn(128, 5, names=('batch', 'D'))
-transition_matrix = torch.randn(5, 5, names=('in', 'out'))
-
-# Apply one transition
-new_state = markov_states @ transition_matrix
-print(new_state.names)
-
-######################################################################
-# As you can see, matrix multiply does not check if the contracted dimensions
-# have the same name.
-#
 # Accessors and Reduction
 # -----------------------
 #
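To illustrate the accessors and reductions referenced by this section header, a brief sketch (assuming ``named_imgs`` with names ``('N', 'C', 'H', 'W')`` as above; not part of this commit):

print(named_imgs.sum('C').names)        # ('N', 'H', 'W') -- reduce over the channel dim by name
print(named_imgs.select('N', 0).names)  # ('C', 'H', 'W') -- index into the batch dim by name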
@@ -179,31 +156,31 @@ def catch_error(fn):
 ######################################################################
 # **Check names**: first, we will check whether the names of these two tensors
 # match. Two names match if and only if they are equal (string equality) or at
-# least one is ``None`` (``None``s are essentially a special wildcard name).
-# The only one of these three that will error, therefore, is ``x+z``:
+# least one is ``None`` (``None`` is essentially a special wildcard name).
+# The only one of these three that will error, therefore, is ``x + z``:

 catch_error(lambda: x + z)

 ######################################################################
-# **Propagate names**: _unify_ the two names by returning the most refined name of
-# the two. With `x + y, `X` is more refined than `None`.
+# **Propagate names**: `unify` the two names by returning the most refined name of
+# the two. With ``x + y``, ``X`` is more refined than ``None``.

 print((x + y).names)

 ######################################################################
 # Most name inference rules are straightforward but some of them can have
-# unexpected semantics. Let's go through a few more of them.
+# unexpected semantics. Let's go through a couple you're likely to encounter:
+# broadcasting and matrix multiply.
 #
-# Broadcasting
-# ------------
+# **Broadcasting**
 #
 # Named tensors do not change broadcasting behavior; they still broadcast by
 # position. However, when checking two dimensions for if they can be
 # broadcasted, the names of those dimensions must match.
 #
 # Furthermore, broadcasting with named tensors can prevent incorrect behavior.
 # The following code will error, whereas without `names` it would add
-# `per_batch_scale` to the last dimension of `imgs`.
+# ``per_batch_scale`` to the last dimension of ``imgs``.

 # Automatic broadcasting: expected to fail
 imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
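Stepping back to the name-matching and unification rules above, a minimal sketch mirroring the tutorial's ``x``, ``y``, and ``z`` (the definitions here are illustrative, not part of this commit):

x = torch.randn(3, names=('X',))
y = torch.randn(3)                 # unnamed; its names are (None,)
z = torch.randn(3, names=('Z',))
print((x + y).names)               # ('X',) -- None unifies with the more refined 'X'
# x + z raises an error, since 'X' and 'Z' do not match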
@@ -212,17 +189,37 @@ def catch_error(fn):

 ######################################################################
 # How `should` we perform this broadcasting operation along the first
-# dimension? One way, involving names, would be to explicitly initialize
-# the ``per_batch_scale`` tensor as four dimensional, and give it names
-# (such as ``('N', None, None, None)`` below) so that name inference will
-# work.
+# dimension? One way, involving names, would be to name the ``per_batch_scale``
+# tensor such that it matches ``imgs.names``, as shown below.

 imgs = torch.randn(6, 6, 6, 6, names=('N', 'C', 'H', 'W'))
-per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', None, None, None))
+per_batch_scale_4d = torch.rand(6, 1, 1, 1, names=('N', 'C', 'H', 'W'))
 print((imgs * per_batch_scale_4d).names)

 ######################################################################
-# However, named tensors enable an even better way, which we'll cover next.
+# Another way would be to use the new explicit broadcasting by names
+# functionality, covered below.
+#
+# **Matrix multiply**
+#
+# ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
+# and the first dim of ``B``, returning a tensor with the first dim of ``A``
+# and the second dim of ``B``. (the other matmul functions, such as ``torch.matmul``,
+# ``torch.mv``, ``torch.dot``, behave similarly).
+
+markov_states = torch.randn(128, 5, names=('batch', 'D'))
+transition_matrix = torch.randn(5, 5, names=('in', 'out'))
+
+# Apply one transition
+new_state = markov_states @ transition_matrix
+print(new_state.names)
+
+######################################################################
+# As you can see, matrix multiply does not check if the contracted dimensions
+# have the same name.
+#
+# Next, we'll cover two new behaviors that named tensors enable: explicit
+# broadcasting by names.
 #
 # New behavior: Explicit broadcasting by names
 # --------------------------------------------
@@ -250,6 +247,7 @@ def catch_error(fn):
 per_batch_scale = per_batch_scale.refine_names('N')

 named_result = imgs * per_batch_scale.align_as(imgs)
+# note: named tensors do not yet work with allclose
 assert torch.allclose(named_result.rename(None), correct_result)

 ######################################################################
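For reference, a short sketch of what ``align_as`` does in the line above (sizes are illustrative; not part of this commit):

imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))
aligned = per_batch_scale.align_as(imgs)   # add size-1 dims for the missing names
print(aligned.names, aligned.shape)        # ('N', 'C', 'H', 'W') torch.Size([2, 1, 1, 1])
print((imgs * aligned).names)              # ('N', 'C', 'H', 'W')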
@@ -307,14 +305,12 @@ def catch_error(fn):
 # ------------------------------------------
 #
 # See here (link to be included) for a detailed breakdown of what is
-# supported with the 1.3 release, what is on the roadmap to be supported soon,
-# and what will be supported in the future but not soon.
+# supported with the 1.3 release.
 #
 # In particular, we want to call out three important features that are not
 # currently supported:
 #
-# - Retaining names when serializing or loading a serialized ``Tensor`` via
-# ``torch.save``
+# - Saving or loading named tensors via ``torch.save`` or ``torch.load``
 # - Multi-processing via ``torch.multiprocessing``
 # - JIT support; for example, the following will error
320316
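For the first limitation above, a hedged sketch of the ``rename(None)`` workaround that the tutorial suggests further down (the file name is hypothetical):

named = torch.randn(2, 3, names=('N', 'C'))
torch.save(named.rename(None), 'tensor.pt')                 # drop names before serializing
restored = torch.load('tensor.pt').refine_names('N', 'C')   # re-attach names after loading
print(restored.names)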

@@ -330,15 +326,15 @@ def fn(x):
330326
# As a workaround, please drop names via ``tensor = tensor.rename(None)``
331327
# before using anything that does not yet support named tensors.
332328
#
333-
# Longer example: Multi-headed attention
329+
# Longer example: Multi-head attention
334330
# --------------------------------------
335331
#
336332
# Now we'll go through a complete example of implementing a common
337-
# PyTorch ``nn.Module``: multi-headed attention. We assume the reader is already
338-
# familiar with multi-headed attention; for a refresher, check out
333+
# PyTorch ``nn.Module``: multi-head attention. We assume the reader is already
334+
# familiar with multi-head attention; for a refresher, check out
339335
# `this explanation <http://jalammar.github.io/illustrated-transformer/>`_.
340336
#
341-
# We adapt the implementation of multi-headed attention from
337+
# We adapt the implementation of multi-head attention from
342338
# `ParlAI <https://github.com/facebookresearch/ParlAI>`_; specifically
343339
# `here <https://github.com/facebookresearch/ParlAI/blob/f7db35cba3f3faf6097b3e6b208442cd564783d9/parlai/agents/transformer/modules.py#L907>`_.
344340
# Read through the code at that example; then, compare with the code below,
@@ -431,7 +427,7 @@ def forward(self, query, key=None, value=None, mask=None):
 # can be refined to ``['T', 'D']``, preventing potentially silent or confusing size
 # mismatch errors later down the line.
 #
-# **(II) Manipulating dimensions in ``prepare_head``**
+# **(II) Manipulating dimensions in prepare_head**

 # (II)
 def prepare_head(tensor):
@@ -449,8 +445,6 @@ def prepare_head(tensor):
 # multiple heads, finally rearranging the dim order to be ``[..., 'H', 'T', 'D_head']``.
 # ParlAI implements ``prepare_head`` as the following, using ``view`` and ``transpose``
 # operations:
-#
-# **(III) Explicit broadcasting by names**

 def prepare_head(tensor):
 # input is [batch_size, seq_len, n_heads * dim_per_head]
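For comparison (not part of this diff), the name-based counterpart of ``prepare_head`` sketched elsewhere in the tutorial file uses ``unflatten`` and ``align_to``; the helper name and sizes below are illustrative:

def prepare_head_named(tensor, n_heads=4, dim_per_head=16):
    # names: ('N', 'T', 'D') -> ('N', 'H', 'T', 'D_head')
    tensor = tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
    return tensor.align_to(..., 'H', 'T', 'D_head')

q = torch.randn(2, 5, 64, names=('N', 'T', 'D'))
print(prepare_head_named(q).names)   # ('N', 'H', 'T', 'D_head')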
@@ -486,7 +480,7 @@ def ignore():
 # named tensors, we simply align ``attn_mask`` to ``dot_prod`` using ``align_as``
 # and stop worrying about where to ``unsqueeze`` dims.
 #
-# **(IV) More dimension manipulation using ``align_to`` and ``flatten``**
+# **(IV) More dimension manipulation using align_to and flatten**

 def ignore():
 # (IV)
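As a self-contained illustration of the (IV) pattern (not part of this commit), a sketch of reordering dims by name with ``align_to`` and then collapsing named dims with ``flatten``; the sizes are illustrative:

t = torch.randn(2, 3, 4, 5, names=('N', 'H', 'T', 'D_head'))
t = t.align_to('N', 'T', 'H', 'D_head')    # reorder dims by name
t = t.flatten(['H', 'D_head'], 'D')        # collapse the adjacent named dims into 'D'
print(t.names, t.shape)                    # ('N', 'T', 'D') torch.Size([2, 4, 15])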
@@ -497,8 +491,8 @@ def ignore():
 )

 ######################################################################
-# (IV): Like (II), ``align_to`` and ``flatten`` are more semantically
-# meaningful than `view` (despite being more verbose).
+# Here, as in (II), ``align_to`` and ``flatten`` are more semantically
+# meaningful than ``view`` (despite being more verbose).
 #
 # Running the example
 # -------------------
@@ -508,6 +502,18 @@ def ignore():
 mask = torch.ones(n, t, names=('N', 'T'))
 attn = MultiHeadAttention(h, d)
 output = attn(query, mask=mask)
+# works as expected!
+print(output.names)
+
+######################################################################
+# The above works as expected. Furthermore, note that in the code we
+# did not mention the name of the batch dimension at all. In fact,
+# the code is agnostic to the existence of the batch dimensions, so
+# that we can run the following example-level code:
+
+query = torch.randn(t, d, names=('T', 'D'))
+mask = torch.ones(t, names=('T',))
+output = attn(query, mask=mask)
 print(output.names)

 ######################################################################
