Skip to content

Commit ba6070e

Browse files
Thiago Crepaldibrianjo
Thiago Crepaldi
andauthored
Fix model to be properly exported to ONNX (#1144)
Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent fe33b54 commit ba6070e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

beginner_source/transformer_tutorial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ def init_weights(self):
7777
self.decoder.weight.data.uniform_(-initrange, initrange)
7878

7979
def forward(self, src):
80-
if self.src_mask is None or self.src_mask.size(0) != len(src):
80+
if self.src_mask is None or self.src_mask.size(0) != src.size(0):
8181
device = src.device
82-
mask = self._generate_square_subsequent_mask(len(src)).to(device)
82+
mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
8383
self.src_mask = mask
8484

8585
src = self.encoder(src) * math.sqrt(self.ninp)

0 commit comments

Comments
 (0)