Skip to content

Commit 0de7040

Browse files
committed
Fix missing docstring.
1 parent ff74d1f commit 0de7040

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/array_api_extra/_delegation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def one_hot(
135135
An array with integral dtype having shape ``batch_dims``.
136136
num_classes : int
137137
Number of classes in the one-hot dimension.
138+
dtype : DType, optional
139+
The dtype of the return value. Defaults to the default float dtype (usually
140+
float64).
138141
axis : int or tuple of ints, optional
139142
Position(s) in the expanded axes where the new axis is placed.
140143
xp : array_namespace, optional
@@ -146,7 +149,6 @@ def one_hot(
146149
An array having the same shape as `x` except for a new axis at the position
147150
given by `axis` having size `num_classes`.
148151
149-
The dtype of the return value is the default float dtype (usually float64).
150152
151153
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
152154
an exception, or may even cause a bad state. `x` is not checked.

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def one_hot(
389389
dtype: DType,
390390
xp: ModuleType,
391391
) -> Array:
392+
"""Helper for _delegation.one_hot."""
392393
out = xp.zeros((x.size, num_classes), dtype=dtype)
393394
x_flattened = xp.reshape(x, (-1,))
394395
if is_numpy_namespace(xp):

0 commit comments

Comments
 (0)