diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 0ba9711ed67..90c8b902d37 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -77,9 +77,9 @@ def init_weights(self): self.decoder.weight.data.uniform_(-initrange, initrange) def forward(self, src): - if self.src_mask is None or self.src_mask.size(0) != len(src): + if self.src_mask is None or self.src_mask.size(0) != src.size(0): device = src.device - mask = self._generate_square_subsequent_mask(len(src)).to(device) + mask = self._generate_square_subsequent_mask(src.size(0)).to(device) self.src_mask = mask src = self.encoder(src) * math.sqrt(self.ninp)