Commit 972efd8

Author: Seth Weidman

Formatting, wording fixes
1 parent 91ccc85 commit 972efd8

2 files changed: 20 additions, 15 deletions


index.rst

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ Audio
 
 <div style='clear:both'></div>
 
-Named Tensor
+(experimental) Named Tensor
 ----------------------
 
 .. customgalleryitem::

intermediate_source/named_tensor_tutorial.py

Lines changed: 19 additions & 14 deletions
@@ -4,8 +4,6 @@
 *******************************************************
 **Author**: `Richard Zou <https://github.com/zou3519>`_
 
-**Editor**: `Seth Weidman <https://github.com/SethHWeidman>`_
-
 Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
 with tensor dimensions. In most cases, operations that take dimension parameters will accept
 dimension names, avoiding the need to track dimensions by position. In addition, named tensors
@@ -103,10 +101,11 @@
 print(imgs.refine_names(..., 'H', 'W').names)
 
 def catch_error(fn):
-    try:
-        fn()
-    except RuntimeError as err:
-        print(err)
+    fn()
+    assert False
+
+# Actually name 'imgs' using 'refine_names'
+imgs = imgs.refine_names('N', 'C', 'H', 'W')
 
 # Tried to refine an existing name to a different name
 catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width'))
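
For reference, an illustrative sketch (not taken from the tutorial) of the ``refine_names`` behavior the changed lines exercise; the tensor shape here is made up:

    import torch

    imgs = torch.randn(2, 3, 32, 32)
    print(imgs.names)  # (None, None, None, None) -- dims start out unnamed

    # refine_names may only fill in None names; it never renames existing ones.
    imgs = imgs.refine_names('N', 'C', 'H', 'W')
    print(imgs.names)  # ('N', 'C', 'H', 'W')

    # Refining an already-named dim to a different name raises a RuntimeError.
    try:
        imgs.refine_names('batch', 'channel', 'height', 'width')
    except RuntimeError as err:
        print(err)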
@@ -129,7 +128,7 @@ def catch_error(fn):
 # advanced) has not been implemented yet but is on the roadmap. Using the ``named_imgs``
 # tensor from above, we can do:
 
-output = named_imgs.sum(['C']) # Perform a sum over the channel dimension
+output = named_imgs.sum('C') # Perform a sum over the channel dimension
 print(output.names)
 
 img0 = named_imgs.select('N', 0) # get one image
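
For reference, a small illustrative sketch (not taken from the tutorial) of reducing and indexing by dimension name, which is what the corrected ``sum('C')`` line relies on; the shapes are made up:

    import torch

    named_imgs = torch.randn(2, 3, 32, 32, names=('N', 'C', 'H', 'W'))

    output = named_imgs.sum('C')      # reduce over the channel dimension by name
    print(output.names)               # ('N', 'H', 'W')

    img0 = named_imgs.select('N', 0)  # pick out one image; the 'N' dim is dropped
    print(img0.names)                 # ('C', 'H', 'W')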
@@ -172,7 +171,8 @@ def catch_error(fn):
 # unexpected semantics. Let's go through a couple you're likely to encounter:
 # broadcasting and matrix multiply.
 #
-# **Broadcasting**
+# Broadcasting
+# ^^^^^^^^^^^^
 #
 # Named tensors do not change broadcasting behavior; they still broadcast by
 # position. However, when checking two dimensions for if they can be
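
A short illustrative sketch (not taken from the tutorial) of the positional-broadcasting check this new ``Broadcasting`` subsection describes; the tensors are made up:

    import torch

    imgs = torch.randn(2, 3, 32, 32, names=('N', 'C', 'H', 'W'))
    per_batch_scale = torch.rand(2, names=('N',))

    # Broadcasting is still positional: 'N' would line up with 'W', so the name
    # check raises an error instead of silently scaling the wrong dimension.
    try:
        imgs * per_batch_scale
    except RuntimeError as err:
        print(err)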
@@ -200,7 +200,8 @@ def catch_error(fn):
 # Another way would be to use the new explicit broadcasting by names
 # functionality, covered below.
 #
-# **Matrix multiply**
+# Matrix multiply
+# ^^^^^^^^^^^^^^^
 #
 # ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
 # and the first dim of ``B``, returning a tensor with the first dim of ``A``
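
For reference, a minimal sketch (not taken from the tutorial) of the name propagation through ``torch.mm`` described in this new ``Matrix multiply`` subsection, assuming ``torch.mm`` propagates names the same way as the rest of the matmul family; the tensors are illustrative:

    import torch

    markov_states = torch.randn(128, 5, names=('batch', 'D'))
    transition_matrix = torch.randn(5, 5, names=('in', 'out'))

    # The contracted dims are matched by position, so 'D' and 'in' need not
    # share a name; the result keeps the first dim of A and the second dim of B.
    new_state = torch.mm(markov_states, transition_matrix)
    print(new_state.names)  # ('batch', 'out')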
@@ -219,7 +220,7 @@ def catch_error(fn):
 # have the same name.
 #
 # Next, we'll cover two new behaviors that named tensors enable: explicit
-# broadcasting by names.
+# broadcasting by names and flattening and unflattening dimensions by names
 #
 # New behavior: Explicit broadcasting by names
 # --------------------------------------------
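
A brief illustrative sketch (not taken from the tutorial) of the two behaviors the reworded sentence now lists, explicit broadcasting by names and flattening/unflattening of dimensions by names; the shapes are made up:

    import torch

    imgs = torch.randn(2, 3, 32, 32, names=('N', 'C', 'H', 'W'))
    per_batch_scale = torch.rand(2, names=('N',))

    # Explicit broadcasting by names: align_as inserts size-1 dims by name.
    scaled = imgs * per_batch_scale.align_as(imgs)
    print(scaled.names)     # ('N', 'C', 'H', 'W')

    # Flattening and unflattening dimensions by names.
    flat = imgs.flatten(['C', 'H', 'W'], 'features')
    print(flat.names)       # ('N', 'features')
    restored = flat.unflatten('features', (('C', 3), ('H', 32), ('W', 32)))
    print(restored.names)   # ('N', 'C', 'H', 'W')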
@@ -388,6 +389,10 @@ def prepare_head(tensor):
         assert value is None
         if self_attn:
             key = value = query
+        elif value is None:
+            # key and value are the same, but query differs
+            key = key.refine_names(..., 'T', 'D')
+            value = key
         key_len = key.size('T')
         dim = key.size('D')
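
For reference, a small sketch (not taken from the tutorial) of how ``refine_names(..., 'T', 'D')``, as used in the added ``elif`` branch, names only the trailing dims and leaves any leading batch dims unnamed; the shapes are made up:

    import torch

    key = torch.randn(7, 16, 64)    # batched input: (batch, time, dim), unnamed
    print(key.refine_names(..., 'T', 'D').names)  # (None, 'T', 'D')

    key = torch.randn(16, 64)       # the same call also works without a batch dim
    print(key.refine_names(..., 'T', 'D').names)  # ('T', 'D')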

@@ -400,7 +405,7 @@ def prepare_head(tensor):
         dot_prod.refine_names(..., 'H', 'T', 'T_key') # just a check
 
         # (III)
-        attn_mask = (attn_mask == 0).align_as(dot_prod)
+        attn_mask = (mask == 0).align_as(dot_prod)
         dot_prod.masked_fill_(attn_mask, -float(1e20))
 
         attn_weights = self.attn_dropout(F.softmax(dot_prod / scale, dim='T_key'))
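
A minimal illustrative sketch (not taken from the tutorial) of the corrected masking line, with made-up shapes standing in for the attention tensors:

    import torch

    dot_prod = torch.randn(7, 4, 5, 5, names=('N', 'H', 'T', 'T_key'))
    mask = torch.ones(7, 5)
    mask[:, -1] = 0                        # mask out the last key position
    mask = mask.refine_names('N', 'T_key')

    # (mask == 0) is aligned to dot_prod by name before the in-place fill.
    attn_mask = (mask == 0).align_as(dot_prod)
    dot_prod.masked_fill_(attn_mask, -float(1e20))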
@@ -467,7 +472,7 @@ def prepare_head(tensor):
 
 def ignore():
     # (III)
-    attn_mask = (attn_mask == 0).align_as(dot_prod)
+    attn_mask = (mask == 0).align_as(dot_prod)
 
     dot_prod.masked_fill_(attn_mask, -float(1e20))
 
@@ -508,8 +513,8 @@ def ignore():
 ######################################################################
 # The above works as expected. Furthermore, note that in the code we
 # did not mention the name of the batch dimension at all. In fact,
-# the code is agnostic to the existence of the batch dimensions, so
-# that we can run the following example-level code:
+# our ``MultiHeadAttention`` module is agnostic to the existence of batch
+# dimensions.
 
 query = torch.randn(t, d, names=('T', 'D'))
 mask = torch.ones(t, names=('T',))
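
For reference, a tiny sketch (not taken from the tutorial) of batch-agnostic named-tensor code in the spirit of the reworded comment; the helper function and shapes are made up:

    import torch
    import torch.nn.functional as F

    def normalize_scores(scores):
        # Addressing the dim by name means this works with or without a batch dim.
        return F.softmax(scores, dim='T_key')

    with_batch = torch.randn(7, 5, 5, names=('N', 'T', 'T_key'))
    without_batch = torch.randn(5, 5, names=('T', 'T_key'))
    print(normalize_scores(with_batch).names)     # ('N', 'T', 'T_key')
    print(normalize_scores(without_batch).names)  # ('T', 'T_key')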
