Skip to content

Commit d09dde1

Browse files
committed
Always check types in one-hot
1 parent 02544de commit d09dde1

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,13 @@ def one_hot(
394394
) -> Array:
395395
if xp is None:
396396
xp = array_namespace(x)
397+
x_size = x.size
398+
if x_size is None:
399+
msg = "x must have a concrete size."
400+
raise TypeError(msg)
401+
if not xp.isdtype(x.dtype, "integral"):
402+
msg = "x must have an integral dtype."
403+
raise TypeError(msg)
397404
if is_jax_namespace(xp):
398405
assert is_jax_array(x)
399406
from jax.nn import one_hot
@@ -412,10 +419,6 @@ def one_hot(
412419
dtype = xp.empty(()).dtype # Default float dtype
413420
out = xp.zeros((x.size, num_classes), dtype=dtype)
414421
x_flattened = xp.reshape(x, (-1,))
415-
x_size = x.size
416-
if x_size is None:
417-
msg = "x must have a concrete size."
418-
raise TypeError(msg)
419422
if is_numpy_namespace(xp):
420423
at(out)[xp.arange(x_size), x_flattened].set(1)
421424
else:

tests/test_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def test_xp(self, xp: ModuleType):
455455
)
456456
class TestOneHot:
457457
@pytest.mark.parametrize("n_dim", range(4))
458-
@pytest.mark.parametrize("num_classes", range(1, 5, 2))
458+
@pytest.mark.parametrize("num_classes", [1, 3, 10])
459459
def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
460460
shape = tuple(range(2, 2 + n_dim))
461461
rng = np.random.default_rng(2347823)
@@ -508,7 +508,7 @@ def test_axis(self, xp: ModuleType):
508508
xp_assert_equal(actual, expected)
509509

510510
def test_non_integer(self, xp: ModuleType):
511-
with pytest.raises((TypeError, RuntimeError, IndexError, DeprecationWarning)):
511+
with pytest.raises(TypeError):
512512
one_hot(xp.asarray([1.0]), 3)
513513

514514

0 commit comments

Comments
 (0)