diff --git a/intermediate_source/mnist_train_nas.py b/intermediate_source/mnist_train_nas.py index f6f58968a31..54dd1ee7629 100644 --- a/intermediate_source/mnist_train_nas.py +++ b/intermediate_source/mnist_train_nas.py @@ -16,7 +16,7 @@ from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader -from torchmetrics.functional.classification.accuracy import accuracy +from torchmetrics.functional.classification.accuracy import multiclass_accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -106,7 +106,7 @@ def validation_step(self, batch, batch_idx): logits = self(x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) - acc = accuracy(preds, y) + acc = multiclass_accuracy(preds, y, num_classes=self.num_classes) self.log("val_acc", acc, prog_bar=False) return loss