diff --git a/beginner_source/translation_transformer.py b/beginner_source/translation_transformer.py index 8028de25dc4..4c9d559af5e 100644 --- a/beginner_source/translation_transformer.py +++ b/beginner_source/translation_transformer.py @@ -309,7 +309,7 @@ def train_epoch(model, optimizer): optimizer.step() losses += loss.item() - return losses / len(train_dataloader) + return losses / len(list(train_dataloader)) def evaluate(model): @@ -333,7 +333,7 @@ def evaluate(model): loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)) losses += loss.item() - return losses / len(val_dataloader) + return losses / len(list(val_dataloader)) ###################################################################### # Now we have all the ingredients to train our model. Let's do it!