@@ -469,15 +469,11 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
469
469
470
470
def test_basic (self , xp : ModuleType ):
471
471
actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 )
472
- expected = xp .asarray ([[1. , 0. , 0. ],
473
- [0. , 1. , 0. ],
474
- [0. , 0. , 1. ]])
472
+ expected = xp .asarray ([[1.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
475
473
xp_assert_equal (actual , expected )
476
474
477
475
actual = one_hot (xp .asarray ([1 , 2 , 0 ]), 3 )
478
- expected = xp .asarray ([[0. , 1. , 0. ],
479
- [0. , 0. , 1. ],
480
- [1. , 0. , 0. ]])
476
+ expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]])
481
477
xp_assert_equal (actual , expected )
482
478
483
479
@pytest .mark .skip_xp_backend (
@@ -489,32 +485,29 @@ def test_out_of_bound(self, xp: ModuleType):
489
485
actual = one_hot (xp .asarray ([- 1 , 3 ]), 3 )
490
486
except IndexError :
491
487
return
492
- expected = xp .asarray ([[0. , 0. , 0. ],
493
- [0. , 0. , 0. ]])
488
+ expected = xp .asarray ([[0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ]])
494
489
xp_assert_equal (actual , expected )
495
490
496
- @pytest .mark .parametrize ("int_dtype" , ['int8' , 'int16' , 'int32' , 'int64' , 'uint8' ,
497
- 'uint16' , 'uint32' , 'uint64' ])
491
+ @pytest .mark .parametrize (
492
+ "int_dtype" ,
493
+ ["int8" , "int16" , "int32" , "int64" , "uint8" , "uint16" , "uint32" , "uint64" ],
494
+ )
498
495
def test_int_types (self , xp : ModuleType , int_dtype : str ):
499
496
dtype = getattr (xp , int_dtype )
500
497
x = xp .asarray ([0 , 1 , 2 ], dtype = dtype )
501
498
actual = one_hot (x , 3 )
502
- expected = xp .asarray ([[1. , 0. , 0. ],
503
- [0. , 1. , 0. ],
504
- [0. , 0. , 1. ]])
499
+ expected = xp .asarray ([[1.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
505
500
xp_assert_equal (actual , expected )
506
501
507
502
def test_custom_dtype (self , xp : ModuleType ):
508
503
actual = one_hot (xp .asarray ([0 , 1 , 2 ], dtype = xp .int32 ), 3 , dtype = xp .bool )
509
- expected = xp .asarray ([[ True , False , False ],
510
- [False , True , False ],
511
- [ False , False , True ]] )
504
+ expected = xp .asarray (
505
+ [[ True , False , False ], [False , True , False ], [ False , False , True ]]
506
+ )
512
507
xp_assert_equal (actual , expected )
513
508
514
509
def test_axis (self , xp : ModuleType ):
515
- expected = xp .asarray ([[0. , 1. , 0. ],
516
- [0. , 0. , 1. ],
517
- [1. , 0. , 0. ]]).T
510
+ expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]]).T
518
511
actual = one_hot (xp .asarray ([1 , 2 , 0 ]), 3 , axis = 0 )
519
512
xp_assert_equal (actual , expected )
520
513
0 commit comments