We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fe33b54 commit ba6070eCopy full SHA for ba6070e
beginner_source/transformer_tutorial.py
@@ -77,9 +77,9 @@ def init_weights(self):
77
self.decoder.weight.data.uniform_(-initrange, initrange)
78
79
def forward(self, src):
80
- if self.src_mask is None or self.src_mask.size(0) != len(src):
+ if self.src_mask is None or self.src_mask.size(0) != src.size(0):
81
device = src.device
82
- mask = self._generate_square_subsequent_mask(len(src)).to(device)
+ mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
83
self.src_mask = mask
84
85
src = self.encoder(src) * math.sqrt(self.ninp)
0 commit comments