File tree Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -135,6 +135,9 @@ def one_hot(
135
135
An array with integral dtype having shape ``batch_dims``.
136
136
num_classes : int
137
137
Number of classes in the one-hot dimension.
138
+ dtype : DType, optional
139
+ The dtype of the return value. Defaults to the default float dtype (usually
140
+ float64).
138
141
axis : int or tuple of ints, optional
139
142
Position(s) in the expanded axes where the new axis is placed.
140
143
xp : array_namespace, optional
@@ -146,7 +149,6 @@ def one_hot(
146
149
An array having the same shape as `x` except for a new axis at the position
147
150
given by `axis` having size `num_classes`.
148
151
149
- The dtype of the return value is the default float dtype (usually float64).
150
152
151
153
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
152
154
an exception, or may even cause a bad state. `x` is not checked.
Original file line number Diff line number Diff line change @@ -389,6 +389,7 @@ def one_hot(
389
389
dtype : DType ,
390
390
xp : ModuleType ,
391
391
) -> Array :
392
+ """Helper for _delegation.one_hot."""
392
393
out = xp .zeros ((x .size , num_classes ), dtype = dtype )
393
394
x_flattened = xp .reshape (x , (- 1 ,))
394
395
if is_numpy_namespace (xp ):
You can’t perform that action at this time.
0 commit comments