File tree Expand file tree Collapse file tree 2 files changed +9
-6
lines changed Expand file tree Collapse file tree 2 files changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -394,6 +394,13 @@ def one_hot(
394
394
) -> Array :
395
395
if xp is None :
396
396
xp = array_namespace (x )
397
+ x_size = x .size
398
+ if x_size is None :
399
+ msg = "x must have a concrete size."
400
+ raise TypeError (msg )
401
+ if not xp .isdtype (x .dtype , "integral" ):
402
+ msg = "x must have an integral dtype."
403
+ raise TypeError (msg )
397
404
if is_jax_namespace (xp ):
398
405
assert is_jax_array (x )
399
406
from jax .nn import one_hot
@@ -412,10 +419,6 @@ def one_hot(
412
419
dtype = xp .empty (()).dtype # Default float dtype
413
420
out = xp .zeros ((x .size , num_classes ), dtype = dtype )
414
421
x_flattened = xp .reshape (x , (- 1 ,))
415
- x_size = x .size
416
- if x_size is None :
417
- msg = "x must have a concrete size."
418
- raise TypeError (msg )
419
422
if is_numpy_namespace (xp ):
420
423
at (out )[xp .arange (x_size ), x_flattened ].set (1 )
421
424
else :
Original file line number Diff line number Diff line change @@ -455,7 +455,7 @@ def test_xp(self, xp: ModuleType):
455
455
)
456
456
class TestOneHot :
457
457
@pytest .mark .parametrize ("n_dim" , range (4 ))
458
- @pytest .mark .parametrize ("num_classes" , range ( 1 , 5 , 2 ) )
458
+ @pytest .mark .parametrize ("num_classes" , [ 1 , 3 , 10 ] )
459
459
def test_dims_and_classes (self , xp : ModuleType , n_dim : int , num_classes : int ):
460
460
shape = tuple (range (2 , 2 + n_dim ))
461
461
rng = np .random .default_rng (2347823 )
@@ -508,7 +508,7 @@ def test_axis(self, xp: ModuleType):
508
508
xp_assert_equal (actual , expected )
509
509
510
510
def test_non_integer (self , xp : ModuleType ):
511
- with pytest .raises (( TypeError , RuntimeError , IndexError , DeprecationWarning ) ):
511
+ with pytest .raises (TypeError ):
512
512
one_hot (xp .asarray ([1.0 ]), 3 )
513
513
514
514
You can’t perform that action at this time.
0 commit comments