Commit 291dfbb

Author: Andrew Hoblitzell
Commit message: generate_square_subsequent_mask
1 parent 789fc09 commit 291dfbb

File tree

1 file changed (+2, -1)

beginner_source/transformer_tutorial.py (2 additions, 1 deletion)

```diff
@@ -91,11 +91,12 @@ def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
+        if src_mask is None:
+            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
         output = self.transformer_encoder(src, src_mask)
         output = self.linear(output)
         return output
 
-
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
 # relative or absolute position of the tokens in the sequence. The
```
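The change above makes `src_mask` optional: when the caller passes no mask, the model builds a causal mask itself with `nn.Transformer.generate_square_subsequent_mask`. The sketch below shows what that mask looks like in isolation; it assumes a working PyTorch install and omits the tutorial's `device` variable, which is defined elsewhere in the file.

```python
import torch
import torch.nn as nn

# generate_square_subsequent_mask(sz) returns an sz x sz float mask:
# 0.0 on and below the diagonal, -inf above it. Added to attention
# scores, the -inf entries stop position i from attending to any
# later position j > i (the "subsequent" tokens).
mask = nn.Transformer.generate_square_subsequent_mask(3)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```

In the patched `forward`, `len(src)` is the sequence length (the first dimension of `src`), so the generated mask matches the attention-score matrix that `transformer_encoder` computes for each position.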
