@@ -467,7 +467,7 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
467
467
i = tuple (i_list )
468
468
assert y [* i , j ] == (x [i ] == j )
469
469
470
- def test_one_hot (self , xp : ModuleType ):
470
+ def test_basic (self , xp : ModuleType ):
471
471
actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 )
472
472
expected = xp .asarray ([[1. , 0. , 0. ],
473
473
[0. , 1. , 0. ],
@@ -480,7 +480,7 @@ def test_one_hot(self, xp: ModuleType):
480
480
[1. , 0. , 0. ]])
481
481
xp_assert_equal (actual , expected )
482
482
483
- def test_one_hot_out_of_bound (self , xp : ModuleType ):
483
+ def test_out_of_bound (self , xp : ModuleType ):
484
484
# Undefined behavior. Either return zero, or raise.
485
485
try :
486
486
actual = one_hot (xp .asarray ([- 1 , 3 ]), 3 )
@@ -490,14 +490,14 @@ def test_one_hot_out_of_bound(self, xp: ModuleType):
490
490
[0. , 0. , 0. ]])
491
491
xp_assert_equal (actual , expected )
492
492
493
- def test_one_hot_custom_dtype (self , xp : ModuleType ):
493
+ def test_custom_dtype (self , xp : ModuleType ):
494
494
actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 , dtype = xp .bool )
495
495
expected = xp .asarray ([[True , False , False ],
496
496
[False , True , False ],
497
497
[False , False , True ]])
498
498
xp_assert_equal (actual , expected )
499
499
500
- def test_one_hot_axis (self , xp : ModuleType ):
500
+ def test_axis (self , xp : ModuleType ):
501
501
expected = xp .asarray ([[0. , 1. , 0. ],
502
502
[0. , 0. , 1. ],
503
503
[1. , 0. , 0. ]]).T
@@ -507,7 +507,7 @@ def test_one_hot_axis(self, xp: ModuleType):
507
507
actual = one_hot (xp .asarray ([1 , 2 , 0 ]), 3 , axis = - 2 )
508
508
xp_assert_equal (actual , expected )
509
509
510
- def test_one_hot_non_integer (self , xp : ModuleType ):
510
+ def test_non_integer (self , xp : ModuleType ):
511
511
with pytest .raises ((TypeError , RuntimeError , IndexError , DeprecationWarning )):
512
512
one_hot (xp .asarray ([1.0 ]), 3 )
513
513
0 commit comments