From 23943186545a4856b99df7203434bd6aa4442baf Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 28 Aug 2020 15:24:34 -0700 Subject: [PATCH] Fix model to be properly exported to ONNX --- beginner_source/transformer_tutorial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 0ba9711ed67..90c8b902d37 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -77,9 +77,9 @@ def init_weights(self): self.decoder.weight.data.uniform_(-initrange, initrange) def forward(self, src): - if self.src_mask is None or self.src_mask.size(0) != len(src): + if self.src_mask is None or self.src_mask.size(0) != src.size(0): device = src.device - mask = self._generate_square_subsequent_mask(len(src)).to(device) + mask = self._generate_square_subsequent_mask(src.size(0)).to(device) self.src_mask = mask src = self.encoder(src) * math.sqrt(self.ninp)