Skip to content

ENH: add new function one_hot #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
isclose
kron
nunique
one_hot
pad
setdiff1d
sinc
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, pad
from ._delegation import isclose, one_hot, pad
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -34,6 +34,7 @@
"kron",
"lazy_apply",
"nunique",
"one_hot",
"pad",
"setdiff1d",
"sinc",
Expand Down
82 changes: 80 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "pad"]
__all__ = ["isclose", "one_hot", "pad"]


def isclose(
Expand Down Expand Up @@ -112,6 +113,83 @@ def isclose(
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def one_hot(
x: Array,
/,
num_classes: int,
*,
dtype: DType | None = None,
axis: int = -1,
xp: ModuleType | None = None,
) -> Array:
"""
One-hot encode the given indices.

Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
with the element at the given index set to one.

Parameters
----------
x : array
An array with integral dtype whose values are between `0` and `num_classes - 1`.
num_classes : int
Number of classes in the one-hot dimension.
dtype : DType, optional
The dtype of the return value. Defaults to the default float dtype (usually
float64).
axis : int, optional
Position in the expanded axes where the new axis is placed. Default: -1.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
An array having the same shape as `x` except for a new axis at the position
given by `axis` having size `num_classes`. If `axis` is unspecified, it
defaults to -1, which appends a new axis.

If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
an exception, or may even cause a bad state. `x` is not checked.

Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
Array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]], dtype=array_api_strict.float64)
"""
# Validate inputs.
if xp is None:
xp = array_namespace(x)
if not xp.isdtype(x.dtype, "integral"):
msg = "x must have an integral dtype."
raise TypeError(msg)
if dtype is None:
dtype = _funcs.default_dtype(xp, device=get_device(x))
# Delegate where possible.
if is_jax_namespace(xp):
from jax.nn import one_hot as jax_one_hot

return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
if is_torch_namespace(xp):
from torch.nn.functional import one_hot as torch_one_hot

x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
try:
out = torch_one_hot(x, num_classes)
except RuntimeError as e:
raise IndexError from e
else:
out = _funcs.one_hot(x, num_classes, xp=xp)
out = xp.astype(out, dtype, copy=False)
if axis != -1:
out = xp.moveaxis(out, -1, axis)
return out


def pad(
x: Array,
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
Expand Down
17 changes: 17 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.squeeze(c, axis=axes)


def one_hot(
x: Array,
/,
num_classes: int,
*,
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
# TODO: Benchmark whether this is faster on the NumPy backend:
# if is_numpy_array(x):
# out = xp.zeros((x.size, num_classes), dtype=dtype)
# out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1
# return xp.reshape(out, (*x.shape, num_classes))
range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
return x[..., xp.newaxis] == range_num_classes


def create_diagonal(
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
) -> Array:
Expand Down
94 changes: 94 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
isclose,
kron,
nunique,
one_hot,
pad,
setdiff1d,
sinc,
Expand All @@ -45,6 +46,7 @@
lazy_xp_function(expand_dims)
lazy_xp_function(kron)
lazy_xp_function(nunique)
lazy_xp_function(one_hot)
lazy_xp_function(pad)
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False)
Expand Down Expand Up @@ -449,6 +451,98 @@ def test_xp(self, xp: ModuleType):
)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
class TestOneHot:
@pytest.mark.parametrize("n_dim", range(4))
@pytest.mark.parametrize("num_classes", [1, 3, 10])
def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
shape = tuple(range(2, 2 + n_dim))
rng = np.random.default_rng(2347823)
np_x = rng.integers(num_classes, size=shape)
x = xp.asarray(np_x)
y = one_hot(x, num_classes)
assert y.shape == (*x.shape, num_classes)
for *i_list, j in ndindex(*shape, num_classes):
i = tuple(i_list)
assert float(y[(*i, j)]) == (int(x[i]) == j)

def test_basic(self, xp: ModuleType):
actual = one_hot(xp.asarray([0, 1, 2]), 3)
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
xp_assert_equal(actual, expected)

actual = one_hot(xp.asarray([1, 2, 0]), 3)
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
xp_assert_equal(actual, expected)

def test_2d(self, xp: ModuleType):
actual = one_hot(xp.asarray([[2, 1, 0], [1, 0, 2]]), 3, axis=1)
expected = xp.asarray(
[
[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
]
)
xp_assert_equal(actual, expected)

@pytest.mark.skip_xp_backend(
Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing"
)
def test_abstract_size(self, xp: ModuleType):
x = xp.arange(5)
x = x[x > 2]
actual = one_hot(x, 5)
expected = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]])
xp_assert_equal(actual, expected)

@pytest.mark.skip_xp_backend(
Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
)
def test_out_of_bound(self, xp: ModuleType):
# Undefined behavior. Either return zero, or raise.
try:
actual = one_hot(xp.asarray([-1, 3]), 3)
except IndexError:
return
expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
xp_assert_equal(actual, expected)

@pytest.mark.parametrize(
"int_dtype",
["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
)
def test_int_types(self, xp: ModuleType, int_dtype: str):
dtype = getattr(xp, int_dtype)
x = xp.asarray([0, 1, 2], dtype=dtype)
actual = one_hot(x, 3)
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
xp_assert_equal(actual, expected)

def test_custom_dtype(self, xp: ModuleType):
actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
expected = xp.asarray(
[[True, False, False], [False, True, False], [False, False, True]]
)
xp_assert_equal(actual, expected)

def test_axis(self, xp: ModuleType):
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
xp_assert_equal(actual, expected)

actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
xp_assert_equal(actual, expected)

def test_non_integer(self, xp: ModuleType):
with pytest.raises(TypeError):
_ = one_hot(xp.asarray([1.0]), 3)

def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([0, 1, 2], device=device)
y = one_hot(x, 3)
assert get_device(y) == device


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
Expand Down