diff --git a/docs/api-reference.md b/docs/api-reference.md index 8e9375d0..38d0d26e 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -17,6 +17,7 @@ isclose kron nunique + one_hot pad setdiff1d sinc diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 5cfe8594..ba9de3b4 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, pad +from ._delegation import isclose, one_hot, pad from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -34,6 +34,7 @@ "kron", "lazy_apply", "nunique", + "one_hot", "pad", "setdiff1d", "sinc", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b52c23ae..756841c8 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -14,10 +14,11 @@ is_pydata_sparse_namespace, is_torch_namespace, ) +from ._lib._utils._compat import device as get_device from ._lib._utils._helpers import asarrays -from ._lib._utils._typing import Array +from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "pad"] +__all__ = ["isclose", "one_hot", "pad"] def isclose( @@ -112,6 +113,83 @@ def isclose( return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) +def one_hot( + x: Array, + /, + num_classes: int, + *, + dtype: DType | None = None, + axis: int = -1, + xp: ModuleType | None = None, +) -> Array: + """ + One-hot encode the given indices. + + Each index in the input `x` is encoded as a vector of zeros of length `num_classes` + with the element at the given index set to one. + + Parameters + ---------- + x : array + An array with integral dtype whose values are between `0` and `num_classes - 1`. + num_classes : int + Number of classes in the one-hot dimension. + dtype : DType, optional + The dtype of the return value. Defaults to the default float dtype (usually + float64). + axis : int, optional + Position in the expanded axes where the new axis is placed. Default: -1. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + An array having the same shape as `x` except for a new axis at the position + given by `axis` having size `num_classes`. If `axis` is unspecified, it + defaults to -1, which appends a new axis. + + If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise + an exception, or may even cause a bad state. `x` is not checked. + + Examples + -------- + >>> import array_api_extra as xpx + >>> import array_api_strict as xp + >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3) + Array([[0., 1., 0.], + [0., 0., 1.], + [1., 0., 0.]], dtype=array_api_strict.float64) + """ + # Validate inputs. + if xp is None: + xp = array_namespace(x) + if not xp.isdtype(x.dtype, "integral"): + msg = "x must have an integral dtype." + raise TypeError(msg) + if dtype is None: + dtype = _funcs.default_dtype(xp, device=get_device(x)) + # Delegate where possible. + if is_jax_namespace(xp): + from jax.nn import one_hot as jax_one_hot + + return jax_one_hot(x, num_classes, dtype=dtype, axis=axis) + if is_torch_namespace(xp): + from torch.nn.functional import one_hot as torch_one_hot + + x = xp.astype(x, xp.int64) # PyTorch only supports int64 here. + try: + out = torch_one_hot(x, num_classes) + except RuntimeError as e: + raise IndexError from e + else: + out = _funcs.one_hot(x, num_classes, xp=xp) + out = xp.astype(out, dtype, copy=False) + if axis != -1: + out = xp.moveaxis(out, -1, axis) + return out + + def pad( x: Array, pad_width: int | tuple[int, int] | Sequence[tuple[int, int]], diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index be703fb5..69dfe6a4 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.squeeze(c, axis=axes) +def one_hot( + x: Array, + /, + num_classes: int, + *, + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + # TODO: Benchmark whether this is faster on the NumPy backend: + # if is_numpy_array(x): + # out = xp.zeros((x.size, num_classes), dtype=dtype) + # out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1 + # return xp.reshape(out, (*x.shape, num_classes)) + range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x)) + return x[..., xp.newaxis] == range_num_classes + + def create_diagonal( x: Array, /, *, offset: int = 0, xp: ModuleType | None = None ) -> Array: diff --git a/tests/test_funcs.py b/tests/test_funcs.py index b89c7441..c8bea859 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -22,6 +22,7 @@ isclose, kron, nunique, + one_hot, pad, setdiff1d, sinc, @@ -45,6 +46,7 @@ lazy_xp_function(expand_dims) lazy_xp_function(kron) lazy_xp_function(nunique) +lazy_xp_function(one_hot) lazy_xp_function(pad) # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) @@ -449,6 +451,98 @@ def test_xp(self, xp: ModuleType): ) +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange") +class TestOneHot: + @pytest.mark.parametrize("n_dim", range(4)) + @pytest.mark.parametrize("num_classes", [1, 3, 10]) + def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int): + shape = tuple(range(2, 2 + n_dim)) + rng = np.random.default_rng(2347823) + np_x = rng.integers(num_classes, size=shape) + x = xp.asarray(np_x) + y = one_hot(x, num_classes) + assert y.shape == (*x.shape, num_classes) + for *i_list, j in ndindex(*shape, num_classes): + i = tuple(i_list) + assert float(y[(*i, j)]) == (int(x[i]) == j) + + def test_basic(self, xp: ModuleType): + actual = one_hot(xp.asarray([0, 1, 2]), 3) + expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + xp_assert_equal(actual, expected) + + actual = one_hot(xp.asarray([1, 2, 0]), 3) + expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + xp_assert_equal(actual, expected) + + def test_2d(self, xp: ModuleType): + actual = one_hot(xp.asarray([[2, 1, 0], [1, 0, 2]]), 3, axis=1) + expected = xp.asarray( + [ + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + ] + ) + xp_assert_equal(actual, expected) + + @pytest.mark.skip_xp_backend( + Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing" + ) + def test_abstract_size(self, xp: ModuleType): + x = xp.arange(5) + x = x[x > 2] + actual = one_hot(x, 5) + expected = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]) + xp_assert_equal(actual, expected) + + @pytest.mark.skip_xp_backend( + Backend.TORCH_GPU, reason="Puts Pytorch into a bad state." + ) + def test_out_of_bound(self, xp: ModuleType): + # Undefined behavior. Either return zero, or raise. + try: + actual = one_hot(xp.asarray([-1, 3]), 3) + except IndexError: + return + expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + xp_assert_equal(actual, expected) + + @pytest.mark.parametrize( + "int_dtype", + ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"], + ) + def test_int_types(self, xp: ModuleType, int_dtype: str): + dtype = getattr(xp, int_dtype) + x = xp.asarray([0, 1, 2], dtype=dtype) + actual = one_hot(x, 3) + expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + xp_assert_equal(actual, expected) + + def test_custom_dtype(self, xp: ModuleType): + actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool) + expected = xp.asarray( + [[True, False, False], [False, True, False], [False, False, True]] + ) + xp_assert_equal(actual, expected) + + def test_axis(self, xp: ModuleType): + expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T + actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0) + xp_assert_equal(actual, expected) + + actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2) + xp_assert_equal(actual, expected) + + def test_non_integer(self, xp: ModuleType): + with pytest.raises(TypeError): + _ = one_hot(xp.asarray([1.0]), 3) + + def test_device(self, xp: ModuleType, device: Device): + x = xp.asarray([0, 1, 2], device=device) + y = one_hot(x, 3) + assert get_device(y) == device + + @pytest.mark.skip_xp_backend( Backend.SPARSE, reason="read-only backend without .at support" )