Skip to content

Commit 8dba3f8

Browse files
committed
Skip Dask, and use flags
1 parent eea5935 commit 8dba3f8

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

src/array_api_extra/_delegation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,15 @@ def one_hot(
188188
raise IndexError from e
189189
out = xp.astype(out, dtype)
190190
else:
191-
out = _funcs.one_hot(x, num_classes, x_size=x_size, dtype=dtype, xp=xp)
191+
out = _funcs.one_hot(
192+
x,
193+
num_classes,
194+
x_size=x_size,
195+
dtype=dtype,
196+
xp=xp,
197+
supports_fancy_indexing=is_numpy_namespace(xp),
198+
supports_array_indexing=is_dask_namespace(xp),
199+
)
192200

193201
if axis != -1:
194202
out = xp.moveaxis(out, -1, axis)

src/array_api_extra/_lib/_funcs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
array_namespace,
1313
is_dask_namespace,
1414
is_jax_array,
15-
is_numpy_namespace,
1615
)
1716
from ._utils._helpers import (
1817
asarrays,
@@ -385,17 +384,22 @@ def one_hot(
385384
/,
386385
num_classes: int,
387386
*,
387+
supports_fancy_indexing: bool = False,
388+
supports_array_indexing: bool = False,
388389
x_size: int,
389390
dtype: DType,
390391
xp: ModuleType,
391392
) -> Array: # numpydoc ignore=PR01,RT01
392393
"""See docstring in `array_api_extra._delegation.py`."""
393394
out = xp.zeros((x.size, num_classes), dtype=dtype)
394395
x_flattened = xp.reshape(x, (-1,))
395-
if is_numpy_namespace(xp):
396+
if supports_fancy_indexing:
396397
out = at(out)[xp.arange(x_size), x_flattened].set(1)
397398
for i in range(x_size):
398-
out = at(out)[i, int(x_flattened[i])].set(1)
399+
x_i = x_flattened[i]
400+
if not supports_array_indexing:
401+
x_i = int(x_i)
402+
out = at(out)[i, x_i].set(1)
399403
if x.ndim != 1:
400404
out = xp.reshape(out, (*x.shape, num_classes))
401405
return out

tests/test_funcs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,9 @@ def test_xp(self, xp: ModuleType):
453453
@pytest.mark.skip_xp_backend(
454454
Backend.SPARSE, reason="read-only backend without .at support"
455455
)
456+
@pytest.mark.skip_xp_backend(
457+
Backend.DASK, reason="backend does not yet support indexed assignment"
458+
)
456459
class TestOneHot:
457460
@pytest.mark.parametrize("n_dim", range(4))
458461
@pytest.mark.parametrize("num_classes", [1, 3, 10])

0 commit comments

Comments
 (0)