File tree Expand file tree Collapse file tree 1 file changed +35
-0
lines changed Expand file tree Collapse file tree 1 file changed +35
-0
lines changed Original file line number Diff line number Diff line change @@ -392,6 +392,41 @@ def one_hot(
392
392
axis : int = - 1 ,
393
393
xp : ModuleType | None = None ,
394
394
) -> Array :
395
+ """
396
+ One-hot encode the given indices.
397
+
398
+ Each index in the input ``x`` is encoded as a vector of zeros of length
399
+ ``num_classes`` with the element at the given index set to one.
400
+
401
+ Parameters
402
+ ----------
403
+ x : array
404
+ An array with integral dtype having shape ``batch_dims``.
405
+ num_classes : int
406
+ Number of classes in the one-hot dimension.
407
+ axis : int or tuple of ints, optional
408
+ Position(s) in the expanded axes where the new axis is placed.
409
+ xp : array_namespace, optional
410
+ The standard-compatible namespace for `x`. Default: infer.
411
+
412
+ Returns
413
+ -------
414
+ array
415
+ An array having the same shape as `x` except for a new axis at the position
416
+ given by `axis` having size `num_classes`.
417
+
418
+ The dtype of the return value is the default float dtype (usually float64).
419
+
420
+ If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
421
+ an exception, or may even cause a bad state. `x` is not checked.
422
+
423
+ Examples
424
+ --------
425
+ >>> xp.one_hot(jnp.array([1, 2, 0]), 3)
426
+ Array([[0., 1., 0.],
427
+ [0., 0., 1.],
428
+ [1., 0., 0.]], dtype=float64)
429
+ """
395
430
if xp is None :
396
431
xp = array_namespace (x )
397
432
x_size = x .size
You can’t perform that action at this time.
0 commit comments