Skip to content

Commit 9d53696

Browse files
authored
Update chatbot_tutorial.py (#1212)
1 parent e5f60c6 commit 9d53696

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

beginner_source/chatbot_tutorial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,9 +968,10 @@ def train(input_variable, lengths, target_variable, mask, max_target_len, encode
968968

969969
# Set device options
970970
input_variable = input_variable.to(device)
971-
lengths = lengths.to(device)
972971
target_variable = target_variable.to(device)
973972
mask = mask.to(device)
973+
# Lengths for rnn packing should always be on the cpu
974+
lengths = lengths.to("cpu")
974975

975976
# Initialize variables
976977
loss = 0

0 commit comments

Comments
 (0)