Skip to content

Commit 02544de

Browse files
committed
Use iterative approach when non-numpy
1 parent 2dc1e07 commit 02544de

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
array_namespace,
1313
is_dask_namespace,
1414
is_jax_array,
15+
is_jax_namespace,
16+
is_numpy_namespace,
1517
is_torch_array,
18+
is_torch_namespace,
1619
)
1720
from ._utils._helpers import (
1821
asarrays,
@@ -391,22 +394,33 @@ def one_hot(
391394
) -> Array:
392395
if xp is None:
393396
xp = array_namespace(x)
394-
if is_jax_array(x):
397+
if is_jax_namespace(xp):
398+
assert is_jax_array(x)
395399
from jax.nn import one_hot
396400
if dtype is None:
397401
dtype = xp.float_
398402
return one_hot(x, num_classes, dtype=dtype, axis=axis)
399-
if is_torch_array(x):
403+
if is_torch_namespace(xp):
404+
assert is_torch_array(x)
400405
from torch.nn.functional import one_hot
401406
out = one_hot(x, num_classes)
402407
if dtype is None:
403408
dtype = xp.float
404409
out = xp.astype(out, dtype)
405410
else:
406411
if dtype is None:
407-
dtype = xp.float64
412+
dtype = xp.empty(()).dtype # Default float dtype
408413
out = xp.zeros((x.size, num_classes), dtype=dtype)
409-
at(out)[xp.arange(x.size), xp.reshape(x, (-1,))].set(1)
414+
x_flattened = xp.reshape(x, (-1,))
415+
x_size = x.size
416+
if x_size is None:
417+
msg = "x must have a concrete size."
418+
raise TypeError(msg)
419+
if is_numpy_namespace(xp):
420+
at(out)[xp.arange(x_size), x_flattened].set(1)
421+
else:
422+
for i in range(x_size):
423+
at(out)[i, int(x_flattened[i])].set(1)
410424
if x.ndim != 1:
411425
out = xp.reshape(out, (*x.shape, num_classes))
412426
if axis != -1:

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
465465
assert y.shape == (*x.shape, num_classes)
466466
for *i_list, j in ndindex(*shape, num_classes):
467467
i = tuple(i_list)
468-
assert y[*i, j] == (x[i] == j)
468+
assert float(y[*i, j]) == (int(x[i]) == j)
469469

470470
def test_basic(self, xp: ModuleType):
471471
actual = one_hot(xp.asarray([0, 1, 2]), 3)

0 commit comments

Comments
 (0)