Skip to content

Commit 3a458e7

Browse files
authored
Update torchmetrics callsite to fix error due to BC-breaking changes (#2155)
torchmetrics changed the signature of the `accuracy` function in a BC-breaking fashion. This updates the callsite to work with torchmetrics 0.11 and uses `multiclass_accuracy` instead.
1 parent a585992 commit 3a458e7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

intermediate_source/mnist_train_nas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch import nn
1717
from torch.nn import functional as F
1818
from torch.utils.data import DataLoader
19-
from torchmetrics.functional.classification.accuracy import accuracy
19+
from torchmetrics.functional.classification.accuracy import multiclass_accuracy
2020
from torchvision import transforms
2121
from torchvision.datasets import MNIST
2222

@@ -106,7 +106,7 @@ def validation_step(self, batch, batch_idx):
106106
logits = self(x)
107107
loss = F.nll_loss(logits, y)
108108
preds = torch.argmax(logits, dim=1)
109-
acc = accuracy(preds, y)
109+
acc = multiclass_accuracy(preds, y, num_classes=self.num_classes)
110110
self.log("val_acc", acc, prog_bar=False)
111111
return loss
112112

0 commit comments

Comments
 (0)