From 7bef25b53272ea93b31432ad1f02ea93b7f5d548 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Tue, 13 Dec 2022 15:57:37 -0800 Subject: [PATCH] Update torchmetrics callsite to fix error due to BC-breaking changes torchmetrics changed the signature of the `accuracy` function in a BC-breaking fashion. This updates the callsite to work with torchmetrics 0.11 and uses `multiclass_accuracy` instead. --- intermediate_source/mnist_train_nas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/intermediate_source/mnist_train_nas.py b/intermediate_source/mnist_train_nas.py index f6f58968a31..54dd1ee7629 100644 --- a/intermediate_source/mnist_train_nas.py +++ b/intermediate_source/mnist_train_nas.py @@ -16,7 +16,7 @@ from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader -from torchmetrics.functional.classification.accuracy import accuracy +from torchmetrics.functional.classification.accuracy import multiclass_accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -106,7 +106,7 @@ def validation_step(self, batch, batch_idx): logits = self(x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) - acc = accuracy(preds, y) + acc = multiclass_accuracy(preds, y, num_classes=self.num_classes) self.log("val_acc", acc, prog_bar=False) return loss