Commit ff74d1f

Use delegation file
1 parent 27ff917 commit ff74d1f

3 files changed: +96 -87 lines

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -1,6 +1,6 @@
 """Extra array functions built on top of the array API standard."""
 
-from ._delegation import isclose, pad
+from ._delegation import isclose, one_hot, pad
 from ._lib._at import at
 from ._lib._funcs import (
     apply_where,
@@ -11,7 +11,6 @@
     expand_dims,
     kron,
     nunique,
-    one_hot,
     setdiff1d,
     sinc,
 )
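
This re-export changes where one_hot is defined, not how it is reached: the top-level import path is unchanged for users. A quick check (a sketch, assuming the package is importable as array_api_extra):

import array_api_extra as xpx

# one_hot now resolves through array_api_extra._delegation instead of
# array_api_extra._lib._funcs, but the public name is the same.
assert callable(xpx.one_hot)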

src/array_api_extra/_delegation.py

Lines changed: 86 additions & 2 deletions

@@ -9,15 +9,17 @@
     array_namespace,
     is_cupy_namespace,
     is_dask_namespace,
+    is_jax_array,
     is_jax_namespace,
     is_numpy_namespace,
     is_pydata_sparse_namespace,
+    is_torch_array,
     is_torch_namespace,
 )
 from ._lib._utils._helpers import asarrays
-from ._lib._utils._typing import Array
+from ._lib._utils._typing import Array, DType
 
-__all__ = ["isclose", "pad"]
+__all__ = ["isclose", "one_hot", "pad"]
 
 
 def isclose(
@@ -112,6 +114,88 @@ def isclose(
     return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
 
 
+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    dtype: DType | None = None,
+    axis: int = -1,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    One-hot encode the given indices.
+
+    Each index in the input ``x`` is encoded as a vector of zeros of length
+    ``num_classes`` with the element at the given index set to one.
+
+    Parameters
+    ----------
+    x : array
+        An array with integral dtype having shape ``batch_dims``.
+    num_classes : int
+        Number of classes in the one-hot dimension.
+    dtype : DType, optional
+        The dtype of the result. Default: the default float dtype
+        (usually float64).
+    axis : int, optional
+        Position in the expanded axes where the new axis is placed.
+        Default: -1.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `x`. Default: infer.
+
+    Returns
+    -------
+    array
+        An array having the same shape as `x` except for a new axis at the position
+        given by `axis` having size `num_classes`.
+
+        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
+        an exception, or may even cause a bad state. `x` is not checked.
+
+    Examples
+    --------
+    >>> import jax.numpy as jnp
+    >>> one_hot(jnp.array([1, 2, 0]), 3)
+    Array([[0., 1., 0.],
+           [0., 0., 1.],
+           [1., 0., 0.]], dtype=float64)
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(x)
+    x_size = x.size
+    if x_size is None:
+        msg = "x must have a concrete size."
+        raise TypeError(msg)
+    if not xp.isdtype(x.dtype, "integral"):
+        msg = "x must have an integral dtype."
+        raise TypeError(msg)
+    if dtype is None:
+        dtype = xp.empty(()).dtype  # Default float dtype
+    # Delegate where possible.
+    if is_jax_namespace(xp):
+        assert is_jax_array(x)
+        from jax.nn import one_hot
+
+        return one_hot(x, num_classes, dtype=dtype, axis=axis)
+    if is_torch_namespace(xp):
+        assert is_torch_array(x)
+        from torch.nn.functional import one_hot
+
+        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
+        try:
+            out = one_hot(x, num_classes)
+        except RuntimeError as e:
+            raise IndexError from e
+        out = xp.astype(out, dtype)
+    else:
+        out = _funcs.one_hot(x, num_classes, x_size=x_size, dtype=dtype, xp=xp)
+
+    if x.ndim != 1:
+        out = xp.reshape(out, (*x.shape, num_classes))
+    if axis != -1:
+        out = xp.moveaxis(out, -1, axis)
+    return out
+
+
 def pad(
     x: Array,
     pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
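
A minimal usage sketch of the new delegating wrapper, assuming the package is installed and given NumPy input (which falls through to the _funcs.one_hot kernel; JAX and PyTorch arrays are instead handed to jax.nn.one_hot and torch.nn.functional.one_hot):

import numpy as np
import array_api_extra as xpx

# 1-D indices: one row of length num_classes per index.
print(xpx.one_hot(np.array([1, 2, 0]), 3))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]

# Batched indices gain a new class axis; axis= moves it.
batch = np.array([[0, 3], [1, 2]])
print(xpx.one_hot(batch, 4).shape)          # (2, 2, 4)
print(xpx.one_hot(batch, 4, axis=0).shape)  # (4, 2, 2)

Note that the PyTorch branch casts indices to int64 first, since torch's one_hot accepts only int64 input, and it translates torch's RuntimeError on out-of-range indices into an IndexError.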

src/array_api_extra/_lib/_funcs.py

Lines changed: 9 additions & 83 deletions
@@ -12,10 +12,7 @@
     array_namespace,
     is_dask_namespace,
     is_jax_array,
-    is_jax_namespace,
     is_numpy_namespace,
-    is_torch_array,
-    is_torch_namespace,
 )
 from ._utils._helpers import (
     asarrays,
@@ -388,87 +385,16 @@ def one_hot(
     /,
     num_classes: int,
     *,
-    dtype: DType | None = None,
-    axis: int = -1,
-    xp: ModuleType | None = None,
+    x_size: int,
+    dtype: DType,
+    xp: ModuleType,
 ) -> Array:
-    """
-    One-hot encode the given indices.
-
-    Each index in the input ``x`` is encoded as a vector of zeros of length
-    ``num_classes`` with the element at the given index set to one.
-
-    Parameters
-    ----------
-    x : array
-        An array with integral dtype having shape ``batch_dims``.
-    num_classes : int
-        Number of classes in the one-hot dimension.
-    axis : int or tuple of ints, optional
-        Position(s) in the expanded axes where the new axis is placed.
-    xp : array_namespace, optional
-        The standard-compatible namespace for `x`. Default: infer.
-
-    Returns
-    -------
-    array
-        An array having the same shape as `x` except for a new axis at the position
-        given by `axis` having size `num_classes`.
-
-        The dtype of the return value is the default float dtype (usually float64).
-
-        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
-        an exception, or may even cause a bad state. `x` is not checked.
-
-    Examples
-    --------
-    >>> xp.one_hot(jnp.array([1, 2, 0]), 3)
-    Array([[0., 1., 0.],
-           [0., 0., 1.],
-           [1., 0., 0.]], dtype=float64)
-    """
-    if xp is None:
-        xp = array_namespace(x)
-    x_size = x.size
-    if x_size is None:
-        msg = "x must have a concrete size."
-        raise TypeError(msg)
-    if not xp.isdtype(x.dtype, "integral"):
-        msg = "x must have an integral dtype."
-        raise TypeError(msg)
-    if is_jax_namespace(xp):
-        assert is_jax_array(x)
-        from jax.nn import one_hot
-
-        if dtype is None:
-            dtype = xp.float_
-        return one_hot(x, num_classes, dtype=dtype, axis=axis)
-    if is_torch_namespace(xp):
-        assert is_torch_array(x)
-        from torch.nn.functional import one_hot
-
-        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
-        try:
-            out = one_hot(x, num_classes)
-        except RuntimeError as e:
-            raise IndexError from e
-        if dtype is None:
-            dtype = xp.float
-        out = xp.astype(out, dtype)
-    else:
-        if dtype is None:
-            dtype = xp.empty(()).dtype  # Default float dtype
-        out = xp.zeros((x.size, num_classes), dtype=dtype)
-        x_flattened = xp.reshape(x, (-1,))
-        if is_numpy_namespace(xp):
-            out = at(out)[xp.arange(x_size), x_flattened].set(1)
-        else:
-            for i in range(x_size):
-                out = at(out)[i, int(x_flattened[i])].set(1)
-    if x.ndim != 1:
-        out = xp.reshape(out, (*x.shape, num_classes))
-    if axis != -1:
-        out = xp.moveaxis(out, -1, axis)
+    out = xp.zeros((x.size, num_classes), dtype=dtype)
+    x_flattened = xp.reshape(x, (-1,))
+    if is_numpy_namespace(xp):
+        return at(out)[xp.arange(x_size), x_flattened].set(1)
+    for i in range(x_size):
+        out = at(out)[i, int(x_flattened[i])].set(1)
     return out
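
What remains in _funcs.one_hot is the backend-agnostic kernel: validation, dtype resolution, reshaping, and axis handling all now live in the delegation layer. Its NumPy fast path is equivalent to this plain-NumPy sketch (the elementwise loop serves namespaces whose at() path cannot take an integer-array index):

import numpy as np

x = np.array([1, 2, 0])
num_classes = 3

# Scatter ones into an (x.size, num_classes) matrix of zeros:
# row i gets a 1 in column x[i].
out = np.zeros((x.size, num_classes), dtype=np.float64)
out[np.arange(x.size), np.reshape(x, (-1,))] = 1
print(out)
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]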