Skip to content

Commit 8c1d5ca

Browse files
committed
Add one_hot
1 parent 27fbd9c commit 8c1d5ca

File tree

4 files changed

+108
-2
lines changed

4 files changed

+108
-2
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
expand_dims
1616
isclose
1717
kron
18+
one_hot
1819
nunique
1920
pad
2021
setdiff1d

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
create_diagonal,
1111
expand_dims,
1212
kron,
13+
one_hot,
1314
nunique,
1415
setdiff1d,
1516
sinc,
@@ -31,6 +32,7 @@
3132
"isclose",
3233
"kron",
3334
"lazy_apply",
35+
"one_hot",
3436
"nunique",
3537
"pad",
3638
"setdiff1d",

src/array_api_extra/_lib/_funcs.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
11+
from ._utils._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_array,
15+
is_torch_array,
16+
)
1217
from ._utils._helpers import (
1318
asarrays,
1419
capabilities,
1520
eager_shape,
1621
meta_namespace,
1722
ndindex,
1823
)
19-
from ._utils._typing import Array
24+
from ._utils._typing import Array, DType
2025

2126
__all__ = [
2227
"apply_where",
@@ -375,6 +380,40 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
375380
return xp.squeeze(c, axis=axes)
376381

377382

383+
def one_hot(
384+
x: Array,
385+
/,
386+
num_classes: int,
387+
*,
388+
dtype: DType | None = None,
389+
axis: int = -1,
390+
xp: ModuleType | None = None,
391+
) -> Array:
392+
if xp is None:
393+
xp = array_namespace(x)
394+
if is_jax_array(x):
395+
from jax.nn import one_hot
396+
if dtype is None:
397+
dtype = xp.float_
398+
return one_hot(x, num_classes, dtype=dtype, axis=axis)
399+
if is_torch_array(x):
400+
from torch.nn.functional import one_hot
401+
out = one_hot(x, num_classes)
402+
if dtype is None:
403+
dtype = xp.float
404+
out = xp.astype(out, dtype)
405+
else:
406+
if dtype is None:
407+
dtype = xp.float64
408+
out = xp.zeros((x.size, num_classes), dtype=dtype)
409+
at(out)[xp.arange(x.size), xp.reshape(x, (-1,))].set(1)
410+
if x.ndim != 1:
411+
out = xp.reshape(out, (*x.shape, num_classes))
412+
if axis != -1:
413+
out = xp.moveaxis(out, -1, axis)
414+
return out
415+
416+
378417
def create_diagonal(
379418
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
380419
) -> Array:

tests/test_funcs.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
isclose,
2222
kron,
2323
nunique,
24+
one_hot,
2425
pad,
2526
setdiff1d,
2627
sinc,
@@ -44,6 +45,7 @@
4445
lazy_xp_function(expand_dims)
4546
lazy_xp_function(kron)
4647
lazy_xp_function(nunique)
48+
lazy_xp_function(one_hot)
4749
lazy_xp_function(pad)
4850
# FIXME calls in1d which calls xp.unique_values without size
4951
lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,68 @@ def test_xp(self, xp: ModuleType):
448450
)
449451

450452

453+
@pytest.mark.skip_xp_backend(
454+
Backend.SPARSE, reason="read-only backend without .at support"
455+
)
456+
class TestOneHot:
457+
@pytest.mark.parametrize("n_dim", range(4))
458+
@pytest.mark.parametrize("num_classes", range(1, 5, 2))
459+
def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
460+
shape = tuple(range(2, 2 + n_dim))
461+
rng = np.random.default_rng(2347823)
462+
np_x = rng.integers(num_classes, size=shape)
463+
x = xp.asarray(np_x)
464+
y = one_hot(x, num_classes)
465+
assert y.shape == (*x.shape, num_classes)
466+
for *i_list, j in ndindex(*shape, num_classes):
467+
i = tuple(i_list)
468+
assert y[*i, j] == (x[i] == j)
469+
470+
def test_one_hot(self, xp: ModuleType):
471+
actual = one_hot(xp.asarray([0, 1, 2]), 3)
472+
expected = xp.asarray([[1., 0., 0.],
473+
[0., 1., 0.],
474+
[0., 0., 1.]])
475+
xp_assert_equal(actual, expected)
476+
477+
actual = one_hot(xp.asarray([1, 2, 0]), 3)
478+
expected = xp.asarray([[0., 1., 0.],
479+
[0., 0., 1.],
480+
[1., 0., 0.]])
481+
xp_assert_equal(actual, expected)
482+
483+
def test_one_hot_out_of_bound(self, xp: ModuleType):
484+
# Undefined behavior. Either return zero, or raise.
485+
try:
486+
actual = one_hot(xp.asarray([-1, 3]), 3)
487+
except (IndexError, RuntimeError):
488+
return
489+
expected = xp.asarray([[0., 0., 0.],
490+
[0., 0., 0.]])
491+
xp_assert_equal(actual, expected)
492+
493+
def test_one_hot_custom_dtype(self, xp: ModuleType):
494+
actual = one_hot(xp.asarray([0, 1, 2]), 3, dtype=xp.bool)
495+
expected = xp.asarray([[True, False, False],
496+
[False, True, False],
497+
[False, False, True]])
498+
xp_assert_equal(actual, expected)
499+
500+
def test_one_hot_axis(self, xp: ModuleType):
501+
expected = xp.asarray([[0., 1., 0.],
502+
[0., 0., 1.],
503+
[1., 0., 0.]]).T
504+
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
505+
xp_assert_equal(actual, expected)
506+
507+
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
508+
xp_assert_equal(actual, expected)
509+
510+
def test_one_hot_non_integer(self, xp: ModuleType):
511+
with pytest.raises((TypeError, RuntimeError, IndexError, DeprecationWarning)):
512+
one_hot(xp.asarray([1.0]), 3)
513+
514+
451515
@pytest.mark.skip_xp_backend(
452516
Backend.SPARSE, reason="read-only backend without .at support"
453517
)

0 commit comments

Comments
 (0)