Skip to content

Commit 27ff917

Browse files
committed
Add docstring
1 parent 1faccde commit 27ff917

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,41 @@ def one_hot(
392392
axis: int = -1,
393393
xp: ModuleType | None = None,
394394
) -> Array:
395+
"""
396+
One-hot encode the given indices.
397+
398+
Each index in the input ``x`` is encoded as a vector of zeros of length
399+
``num_classes`` with the element at the given index set to one.
400+
401+
Parameters
402+
----------
403+
x : array
404+
An array with integral dtype having shape ``batch_dims``.
405+
num_classes : int
406+
Number of classes in the one-hot dimension.
407+
axis : int or tuple of ints, optional
408+
Position(s) in the expanded axes where the new axis is placed.
409+
xp : array_namespace, optional
410+
The standard-compatible namespace for `x`. Default: infer.
411+
412+
Returns
413+
-------
414+
array
415+
An array having the same shape as `x` except for a new axis at the position
416+
given by `axis` having size `num_classes`.
417+
418+
The dtype of the return value is the default float dtype (usually float64).
419+
420+
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
421+
an exception, or may even cause a bad state. `x` is not checked.
422+
423+
Examples
424+
--------
425+
>>> xp.one_hot(jnp.array([1, 2, 0]), 3)
426+
Array([[0., 1., 0.],
427+
[0., 0., 1.],
428+
[1., 0., 0.]], dtype=float64)
429+
"""
395430
if xp is None:
396431
xp = array_namespace(x)
397432
x_size = x.size

0 commit comments

Comments
 (0)