@@ -420,8 +420,14 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
 _mc_k_targets4 = _mc_k_targets3[:1]
 _mc_k_preds4 = _mc_k_preds3[:1, :]

-_mc_k_targets5 = torch.randint(10, (2, 50))
-_mc_k_preds5 = torch.rand(2, 10, 50)
+_mc_k_targets5 = torch.randint(10, (2, 50), generator=torch.Generator().manual_seed(42))
+_mc_k_preds5 = torch.rand(2, 10, 50, generator=torch.Generator().manual_seed(42))
+
+_mc_k_targets6 = torch.tensor([[3, 2], [1, 0]])
+_mc_k_preds6 = torch.tensor([
+    [[0.0000, 0.1000, 0.5000, 0.4000], [0.0000, 0.2000, 0.7000, 0.1000]],
+    [[0.0000, 0.4000, 0.3000, 0.3000], [1.0000, 0.0000, 0.0000, 0.0000]],
+]).transpose(2, 1)


 @pytest.mark.parametrize(
@@ -435,7 +441,7 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
         (5, _mc_k_preds3, _mc_k_targets3, "micro", 10, torch.tensor(0.5176)),
         (5, _mc_k_preds4, _mc_k_targets4, "macro", 10, torch.tensor(1.0)),
         (5, _mc_k_preds4, _mc_k_targets4, "micro", 10, torch.tensor(1.0)),
-        (5, _mc_k_preds5, _mc_k_targets5, "micro", 10, torch.tensor(0.02)),
+        (5, _mc_k_preds5, _mc_k_targets5, "micro", 10, torch.tensor(0.42)),
     ],
 )
 def test_top_k(k, preds, target, average, num_classes, expected):
@@ -451,6 +457,103 @@ def test_top_k(k, preds, target, average, num_classes, expected):
     )


+@pytest.mark.parametrize(
+    ("preds", "target", "k", "expected"),
+    [
+        (_mc_k_preds6, _mc_k_targets6, 1, torch.tensor(0.6667)),
+        (_mc_k_preds6, _mc_k_targets6, 2, torch.tensor(1.0)),
+        (_mc_k_preds6, _mc_k_targets6, 3, torch.tensor(1.0)),
+        (_mc_k_preds6, _mc_k_targets6, 4, torch.tensor(1.0)),
+    ],
+)
+def test_top_k_with_ignore_index(preds, target, k, expected):
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/3068."""
+    num_classes = 4
+    average = "micro"
+    ignore_index = 0
+
+    class_metric = Accuracy(
+        task="multiclass",
+        ignore_index=ignore_index,
+        num_classes=num_classes,
+        multidim_average="global",
+        average=average,
+        top_k=k,
+    )
+    class_metric.update(preds, target)
+    assert torch.isclose(class_metric.compute(), expected, rtol=1e-4, atol=1e-4)
+    assert torch.isclose(
+        multiclass_accuracy(
+            preds, target, num_classes=num_classes, average=average, top_k=k, ignore_index=ignore_index
+        ),
+        expected,
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+
+@pytest.mark.parametrize("num_classes", [5])
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted"])
+def test_multiclass_accuracy_with_top_k(num_classes, average):
+    """Test that Accuracy increases monotonically with top_k and equals 1 when top_k equals num_classes.
+
+    Args:
+        num_classes: Number of classes in the classification task.
+        average: The averaging method to use (macro, micro, weighted).
+
+    The test verifies two properties:
+    1. Accuracy increases or stays the same as top_k increases
+    2. Accuracy equals 1 when top_k equals num_classes
+
+    """
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    previous_accuracy = 0.0
+    for k in range(1, num_classes + 1):
+        accuracy_score = MulticlassAccuracy(num_classes=num_classes, top_k=k, average=average)
+        accuracy = accuracy_score(preds, target)
+
+        assert accuracy >= previous_accuracy, f"Accuracy did not increase for top_k={k}"
+        previous_accuracy = accuracy
+
+        if k == num_classes:
+            assert torch.isclose(accuracy, torch.tensor(1.0)), (
+                f"Accuracy is not 1 for top_k={k} when num_classes={num_classes}"
+            )
+
+
+@pytest.mark.parametrize(("num_classes", "k"), [(5, 3), (10, 5)])
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted"])
+def test_multiclass_accuracy_top_k_equivalence(num_classes, k, average):
+    """Test that top-k Accuracy scores are equivalent to corrected top-1 scores.
+
+    Args:
+        num_classes: Number of classes in the classification task.
+        k: The top-k value to test.
+        average: The averaging method to use (macro, micro, weighted).
+
+    """
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    accuracy_top_k = MulticlassAccuracy(num_classes=num_classes, top_k=k, average=average)
+    accuracy_top_1 = MulticlassAccuracy(num_classes=num_classes, top_k=1, average=average)
+
+    pred_top_k = torch.argsort(preds, dim=1, descending=True)[:, :k]
+    pred_top_1 = pred_top_k[:, 0]
+    target_in_top_k = (target.unsqueeze(1) == pred_top_k).any(dim=1)
+    pred_corrected_top_k = torch.where(target_in_top_k, target, pred_top_1)
+
+    accuracy_score_top_k = accuracy_top_k(preds, target)
+    accuracy_score_corrected = accuracy_top_1(pred_corrected_top_k, target)
+
+    assert torch.isclose(accuracy_score_top_k, accuracy_score_corrected), (
+        f"Top-{k} Accuracy ({accuracy_score_top_k}) does not match "
+        f"corrected top-1 Accuracy ({accuracy_score_corrected})"
+    )
+
+
 def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, average):
     preds = preds.numpy()
     target = target.numpy()