diff --git a/beginner_source/chatbot_tutorial.py b/beginner_source/chatbot_tutorial.py index 1de320fab91..67bf78b1e30 100644 --- a/beginner_source/chatbot_tutorial.py +++ b/beginner_source/chatbot_tutorial.py @@ -968,9 +968,10 @@ def train(input_variable, lengths, target_variable, mask, max_target_len, encode # Set device options input_variable = input_variable.to(device) - lengths = lengths.to(device) target_variable = target_variable.to(device) mask = mask.to(device) + # Lengths for rnn packing should always be on the cpu + lengths = lengths.to("cpu") # Initialize variables loss = 0