*******************************************************
**Author**: `Richard Zou <https://github.com/zou3519>`_

-**Editor**: `Seth Weidman <https://github.com/SethHWeidman>`_
-
Named Tensors aim to make tensors easier to use by allowing users to associate explicit names
with tensor dimensions. In most cases, operations that take dimension parameters will accept
dimension names, avoiding the need to track dimensions by position. In addition, named tensors
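For readers new to the API, here is a minimal sketch of what that looks like in practice (the tensor shape and names are illustrative and not part of this diff):

import torch

# Name the dimensions at construction time, then reduce over one of them by name.
weights = torch.randn(2, 3, names=('N', 'C'))
print(weights.names)           # ('N', 'C')
print(weights.sum('C').names)  # ('N',) -- no positional bookkeeping needed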
print(imgs.refine_names(..., 'H', 'W').names)

def catch_error(fn):
-    try:
-        fn()
-    except RuntimeError as err:
-        print(err)
+    try:
+        fn()
+        assert False
+    except RuntimeError as err:
+        print(err)
+
+# Actually name 'imgs' using 'refine_names'
+imgs = imgs.refine_names('N', 'C', 'H', 'W')

# Tried to refine an existing name to a different name
catch_error(lambda: imgs.refine_names('batch', 'channel', 'height', 'width'))
@@ -129,7 +128,7 @@ def catch_error(fn):
# advanced) has not been implemented yet but is on the roadmap. Using the ``named_imgs``
# tensor from above, we can do:

-output = named_imgs.sum(['C'])  # Perform a sum over the channel dimension
+output = named_imgs.sum('C')  # Perform a sum over the channel dimension
print(output.names)

img0 = named_imgs.select('N', 0)  # get one image
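For reference, assuming ``named_imgs`` still carries the names ``('N', 'C', 'H', 'W')`` from the earlier ``refine_names`` call, the names propagate as follows (a sketch, not part of the diff):

print(output.names)  # ('N', 'H', 'W') -- the reduced 'C' dimension is dropped
print(img0.names)    # ('C', 'H', 'W') -- select removes the 'N' dimension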
@@ -172,7 +171,8 @@ def catch_error(fn):
# unexpected semantics. Let's go through a couple you're likely to encounter:
# broadcasting and matrix multiply.
#
-# **Broadcasting**
+# Broadcasting
+# ^^^^^^^^^^^^
#
# Named tensors do not change broadcasting behavior; they still broadcast by
# position. However, when checking two dimensions for if they can be
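To make the positional-broadcast-plus-name-check behavior concrete, here is a short sketch; the two tensors are illustrative and not part of this diff, and ``catch_error`` is the helper defined earlier:

imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))

# Broadcasting still aligns from the right, so 'N' would be matched against 'W'.
# Since the names differ, this raises an error instead of silently scaling the
# wrong dimension.
catch_error(lambda: imgs * per_batch_scale)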
@@ -200,7 +200,8 @@ def catch_error(fn):
# Another way would be to use the new explicit broadcasting by names
# functionality, covered below.
#
-# **Matrix multiply**
+# Matrix multiply
+# ^^^^^^^^^^^^^^^
#
# ``torch.mm(A, B)`` performs a dot product between the second dim of ``A``
# and the first dim of ``B``, returning a tensor with the first dim of ``A``
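A short sketch of how names propagate through ``torch.mm`` (the tensors below are assumptions for illustration, not part of this diff):

markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))

# The contraction is positional and does not check names; the output takes the
# first name of A and the second name of B.
new_state = torch.mm(markov_states, transition_matrix)
print(new_state.names)  # ('batch', 'out')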
@@ -219,7 +220,7 @@ def catch_error(fn):
# have the same name.
#
# Next, we'll cover two new behaviors that named tensors enable: explicit
-# broadcasting by names.
+# broadcasting by names and flattening and unflattening dimensions by names.
#
# New behavior: Explicit broadcasting by names
# --------------------------------------------
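As a preview of that section, a hedged sketch of explicit broadcasting by names using ``align_as`` (same illustrative tensors as in the broadcasting sketch above):

imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))

# align_as permutes dimensions by name and inserts size-1 dims for any missing
# names, so the multiplication now broadcasts along 'N' as intended.
scaled = imgs * per_batch_scale.align_as(imgs)
print(scaled.names)  # ('N', 'C', 'H', 'W')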
@@ -388,6 +389,10 @@ def prepare_head(tensor):
        assert value is None
        if self_attn:
            key = value = query
+        elif value is None:
+            # key and value are the same, but query differs
+            key = key.refine_names(..., 'T', 'D')
+            value = key
        key_len = key.size('T')
        dim = key.size('D')
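The ``refine_names(..., 'T', 'D')`` call added above only names the trailing dimensions, leaving any leading batch dimension unnamed; a quick illustrative check (the tensor is hypothetical, not from the diff):

key = torch.randn(7, 5, 6)  # (batch, seq, feature), fully unnamed
print(key.refine_names(..., 'T', 'D').names)  # (None, 'T', 'D')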
@@ -400,7 +405,7 @@ def prepare_head(tensor):
        dot_prod.refine_names(..., 'H', 'T', 'T_key')  # just a check

        # (III)
-        attn_mask = (attn_mask == 0).align_as(dot_prod)
+        attn_mask = (mask == 0).align_as(dot_prod)
        dot_prod.masked_fill_(attn_mask, -float(1e20))

        attn_weights = self.attn_dropout(F.softmax(dot_prod / scale, dim='T_key'))
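One detail worth noting in the last line is ``dim='T_key'``: operations such as softmax accept a dimension name directly. A standalone sketch (the tensor is an assumption for illustration):

import torch.nn.functional as F

scores = torch.randn(2, 3, names=('T', 'T_key'))
probs = F.softmax(scores, dim='T_key')  # softmax over the 'T_key' dimension
print(probs.names)  # ('T', 'T_key') -- names are preserved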
@@ -467,7 +472,7 @@ def prepare_head(tensor):
def ignore():
    # (III)
-    attn_mask = (attn_mask == 0).align_as(dot_prod)
+    attn_mask = (mask == 0).align_as(dot_prod)

    dot_prod.masked_fill_(attn_mask, -float(1e20))
@@ -508,8 +513,8 @@ def ignore():
######################################################################
# The above works as expected. Furthermore, note that in the code we
# did not mention the name of the batch dimension at all. In fact,
-# the code is agnostic to the existence of the batch dimensions, so
-# that we can run the following example-level code:
+# our ``MultiHeadAttention`` module is agnostic to the existence of batch
+# dimensions.

query = torch.randn(t, d, names=('T', 'D'))
mask = torch.ones(t, names=('T',))
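To make the batch-agnostic point concrete, a batched variant of the example inputs would flow through the same name-based code; the sketch below is hypothetical and not part of this diff, and simply leaves the extra leading dimension unnamed:

batched_query = torch.randn(7, t, d, names=(None, 'T', 'D'))
batched_mask = torch.ones(7, t, names=(None, 'T'))
# Nothing in the name-based operations above needs to change for these inputs.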