@@ -480,18 +480,32 @@ def test_basic(self, xp: ModuleType):
480
480
[1. , 0. , 0. ]])
481
481
xp_assert_equal (actual , expected )
482
482
483
+ @pytest .mark .skip_xp_backend (
484
+ Backend .TORCH_GPU , reason = "Puts Pytorch into a bad state."
485
+ )
483
486
def test_out_of_bound (self , xp : ModuleType ):
484
487
# Undefined behavior. Either return zero, or raise.
485
488
try :
486
489
actual = one_hot (xp .asarray ([- 1 , 3 ]), 3 )
487
- except ( IndexError , RuntimeError ) :
490
+ except IndexError :
488
491
return
489
492
expected = xp .asarray ([[0. , 0. , 0. ],
490
493
[0. , 0. , 0. ]])
491
494
xp_assert_equal (actual , expected )
492
495
496
+ @pytest .mark .parametrize ("int_dtype" , ['int8' , 'int16' , 'int32' , 'int64' , 'uint8' ,
497
+ 'uint16' , 'uint32' , 'uint64' ])
498
+ def test_int_types (self , xp : ModuleType , int_dtype : str ):
499
+ dtype = getattr (xp , int_dtype )
500
+ x = xp .asarray ([0 , 1 , 2 ], dtype = dtype )
501
+ actual = one_hot (x , 3 )
502
+ expected = xp .asarray ([[1. , 0. , 0. ],
503
+ [0. , 1. , 0. ],
504
+ [0. , 0. , 1. ]])
505
+ xp_assert_equal (actual , expected )
506
+
493
507
def test_custom_dtype (self , xp : ModuleType ):
494
- actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 , dtype = xp .bool )
508
+ actual = one_hot (xp .asarray ([0 , 1 , 2 ], dtype = xp . int32 ), 3 , dtype = xp .bool )
495
509
expected = xp .asarray ([[True , False , False ],
496
510
[False , True , False ],
497
511
[False , False , True ]])
0 commit comments