
Commit ec28f14

rittik9 authored and Borda committed

Fix top_k in multiclass_accuracy (#3117)

* fix top_k in multiclass_accuracy
* add additional top_k tests
* Update test_accuracy.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

(cherry picked from commit 661c53c)

1 parent 16d9233 · commit ec28f14
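
For context on what the fix changes in practice: the updated test below feeds multidimensional predictions of shape (batch, num_classes, extra_dim) with top_k=5, and its expected micro accuracy moves from 0.02 to 0.42. A minimal sketch reproducing that check (fixture and expected value taken from the updated test; the metric construction here is an assumption of typical usage, not part of this commit):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# Multidim input: preds of shape (batch, num_classes, extra_dim), matching the
# updated `_mc_k_preds5` fixture. Two separate generators, as in the test.
preds = torch.rand(2, 10, 50, generator=torch.Generator().manual_seed(42))
target = torch.randint(10, (2, 50), generator=torch.Generator().manual_seed(42))

metric = MulticlassAccuracy(num_classes=10, top_k=5, average="micro")
print(metric(preds, target))  # tensor(0.4200) per the updated test expectation
```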

File tree

3 files changed: +111 −5 lines changed

CHANGELOG.md
src/torchmetrics/functional/classification/stat_scores.py
tests/unittests/classification/test_accuracy.py

CHANGELOG.md

Lines changed: 4 additions & 1 deletion

@@ -10,14 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Changed

-- Enhance: improve performance of `_rank_data` ([#3103](https://github.com/Lightning-AI/torchmetrics/pull/3103))
+-


 ### Fixed

 - Fixed: Ensure `WrapperMetric` Resets `wrapped_metric` State ([#3123](https://github.com/Lightning-AI/torchmetrics/pull/3123))


+- Fixed: `top_k` in `multiclass_accuracy` ([#3117](https://github.com/Lightning-AI/torchmetrics/pull/3117))
+
+
 ## [1.7.2] - 2025-05-27

 ### Changed

src/torchmetrics/functional/classification/stat_scores.py

Lines changed: 1 addition & 1 deletion

@@ -367,7 +367,7 @@ def _refine_preds_oh(preds: Tensor, preds_oh: Tensor, target: Tensor, top_k: int
     top_1_indices = top_k_indices[:, 0]
     target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1)
     result = torch.where(target_in_topk, target, top_1_indices)
-    return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(1), 1)
+    return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(-1), 1)


 def _multiclass_stat_scores_update(
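
A reading of the one-character fix above: `_refine_preds_oh` keeps the target class whenever it appears among the top-k predictions, falls back to the top-1 prediction otherwise, and one-hot encodes the result with `scatter_` along the last (class) axis. A standalone sketch of that idea, simplified to flat (N, C) inputs (not the library's exact helper):

```python
import torch

def refine_preds_one_hot(preds: torch.Tensor, target: torch.Tensor, top_k: int) -> torch.Tensor:
    """Sketch of the refinement idea: preds (N, C), target (N,) -> one-hot (N, C)."""
    top_k_indices = preds.topk(top_k, dim=-1).indices            # (N, k)
    top_1_indices = top_k_indices[:, 0]                          # fallback prediction
    target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1)
    result = torch.where(target_in_topk, target, top_1_indices)  # (N,)
    # `unsqueeze(-1)` keeps the scatter index aligned with the last (class)
    # axis that `scatter_(-1, ...)` writes into; for inputs carrying extra
    # trailing dimensions (the multidim path), `unsqueeze(1)` can insert the
    # new axis in the wrong place -- which is what this commit fixes.
    return torch.zeros_like(preds, dtype=torch.int32).scatter_(-1, result.unsqueeze(-1), 1)

preds = torch.tensor([[0.1, 0.5, 0.4], [0.8, 0.1, 0.1]])
target = torch.tensor([2, 0])
print(refine_preds_one_hot(preds, target, top_k=2))
# tensor([[0, 0, 1],   <- target 2 is inside the top-2, so it is kept
#         [1, 0, 0]], dtype=torch.int32)
```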

tests/unittests/classification/test_accuracy.py

Lines changed: 106 additions & 3 deletions

@@ -420,8 +420,14 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
 _mc_k_targets4 = _mc_k_targets3[:1]
 _mc_k_preds4 = _mc_k_preds3[:1, :]

-_mc_k_targets5 = torch.randint(10, (2, 50))
-_mc_k_preds5 = torch.rand(2, 10, 50)
+_mc_k_targets5 = torch.randint(10, (2, 50), generator=torch.Generator().manual_seed(42))
+_mc_k_preds5 = torch.rand(2, 10, 50, generator=torch.Generator().manual_seed(42))
+
+_mc_k_targets6 = torch.tensor([[3, 2], [1, 0]])
+_mc_k_preds6 = torch.tensor([
+    [[0.0000, 0.1000, 0.5000, 0.4000], [0.0000, 0.2000, 0.7000, 0.1000]],
+    [[0.0000, 0.4000, 0.3000, 0.3000], [1.0000, 0.0000, 0.0000, 0.0000]],
+]).transpose(2, 1)


 @pytest.mark.parametrize(
@@ -435,7 +441,7 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
         (5, _mc_k_preds3, _mc_k_targets3, "micro", 10, torch.tensor(0.5176)),
         (5, _mc_k_preds4, _mc_k_targets4, "macro", 10, torch.tensor(1.0)),
         (5, _mc_k_preds4, _mc_k_targets4, "micro", 10, torch.tensor(1.0)),
-        (5, _mc_k_preds5, _mc_k_targets5, "micro", 10, torch.tensor(0.02)),
+        (5, _mc_k_preds5, _mc_k_targets5, "micro", 10, torch.tensor(0.42)),
     ],
 )
 def test_top_k(k, preds, target, average, num_classes, expected):
@@ -451,6 +457,103 @@ def test_top_k(k, preds, target, average, num_classes, expected):
     )


+@pytest.mark.parametrize(
+    ("preds", "target", "k", "expected"),
+    [
+        (_mc_k_preds6, _mc_k_targets6, 1, torch.tensor(0.6667)),
+        (_mc_k_preds6, _mc_k_targets6, 2, torch.tensor(1.0)),
+        (_mc_k_preds6, _mc_k_targets6, 3, torch.tensor(1.0)),
+        (_mc_k_preds6, _mc_k_targets6, 4, torch.tensor(1.0)),
+    ],
+)
+def test_top_k_with_ignore_index(preds, target, k, expected):
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/3068."""
+    num_classes = 4
+    average = "micro"
+    ignore_index = 0
+
+    class_metric = Accuracy(
+        task="multiclass",
+        ignore_index=ignore_index,
+        num_classes=num_classes,
+        multidim_average="global",
+        average=average,
+        top_k=k,
+    )
+    class_metric.update(preds, target)
+    assert torch.isclose(class_metric.compute(), expected, rtol=1e-4, atol=1e-4)
+    assert torch.isclose(
+        multiclass_accuracy(
+            preds, target, num_classes=num_classes, average=average, top_k=k, ignore_index=ignore_index
+        ),
+        expected,
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+
+@pytest.mark.parametrize("num_classes", [5])
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted"])
+def test_multiclass_accuracy_with_top_k(num_classes, average):
+    """Test that Accuracy increases monotonically with top_k and equals 1 when top_k equals num_classes.
+
+    Args:
+        num_classes: Number of classes in the classification task.
+        average: The averaging method to use (macro, micro, weighted).
+
+    The test verifies two properties:
+    1. Accuracy increases or stays the same as top_k increases
+    2. Accuracy equals 1 when top_k equals num_classes
+
+    """
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    previous_accuracy = 0.0
+    for k in range(1, num_classes + 1):
+        accuracy_score = MulticlassAccuracy(num_classes=num_classes, top_k=k, average=average)
+        accuracy = accuracy_score(preds, target)
+
+        assert accuracy >= previous_accuracy, f"Accuracy did not increase for top_k={k}"
+        previous_accuracy = accuracy
+
+        if k == num_classes:
+            assert torch.isclose(accuracy, torch.tensor(1.0)), (
+                f"Accuracy is not 1 for top_k={k} when num_classes={num_classes}"
+            )
+
+
+@pytest.mark.parametrize(("num_classes", "k"), [(5, 3), (10, 5)])
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted"])
+def test_multiclass_accuracy_top_k_equivalence(num_classes, k, average):
+    """Test that top-k Accuracy scores are equivalent to corrected top-1 scores.
+
+    Args:
+        num_classes: Number of classes in the classification task.
+        k: The top-k value to test.
+        average: The averaging method to use (macro, micro, weighted).
+
+    """
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    accuracy_top_k = MulticlassAccuracy(num_classes=num_classes, top_k=k, average=average)
+    accuracy_top_1 = MulticlassAccuracy(num_classes=num_classes, top_k=1, average=average)
+
+    pred_top_k = torch.argsort(preds, dim=1, descending=True)[:, :k]
+    pred_top_1 = pred_top_k[:, 0]
+    target_in_top_k = (target.unsqueeze(1) == pred_top_k).any(dim=1)
+    pred_corrected_top_k = torch.where(target_in_top_k, target, pred_top_1)
+
+    accuracy_score_top_k = accuracy_top_k(preds, target)
+    accuracy_score_corrected = accuracy_top_1(pred_corrected_top_k, target)
+
+    assert torch.isclose(accuracy_score_top_k, accuracy_score_corrected), (
+        f"Top-{k} Accuracy ({accuracy_score_top_k}) does not match "
+        f"corrected top-1 Accuracy ({accuracy_score_corrected})"
+    )
+
+
 def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, average):
     preds = preds.numpy()
     target = target.numpy()
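
Where the 0.6667 in test_top_k_with_ignore_index comes from: with ignore_index=0 the fourth sample (target 0) is dropped, leaving three samples; top-1 gets two of them right (the first sample's argmax is class 2 but its target is 3), giving 2/3 ≈ 0.6667, and at k=2 that target enters the top-2, giving 1.0. A quick check with the functional API, reusing the fixture from the diff:

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

# The `_mc_k_preds6` fixture: shape (batch=2, num_classes=4, extra_dim=2) after transpose.
preds = torch.tensor([
    [[0.0, 0.1, 0.5, 0.4], [0.0, 0.2, 0.7, 0.1]],
    [[0.0, 0.4, 0.3, 0.3], [1.0, 0.0, 0.0, 0.0]],
]).transpose(2, 1)
target = torch.tensor([[3, 2], [1, 0]])

# Sample with target 0 is ignored; 2 of the remaining 3 are right at top-1.
print(multiclass_accuracy(preds, target, num_classes=4, average="micro", top_k=1, ignore_index=0))  # ~0.6667
print(multiclass_accuracy(preds, target, num_classes=4, average="micro", top_k=2, ignore_index=0))  # 1.0
```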
