Skip to content

Commit 027b60a

Browse files
committed
Test abstract array
1 parent 8dba3f8 commit 027b60a

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_funcs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import hypothesis.extra.numpy as npst
88
import numpy as np
99
import pytest
10+
from array_api_compat import is_jax_namespace
1011
from hypothesis import given
1112
from hypothesis import strategies as st
1213

@@ -470,6 +471,16 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
470471
i = tuple(i_list)
471472
assert float(y[(*i, j)]) == (int(x[i]) == j)
472473

474+
def test_empty_shape(self, xp: ModuleType):
475+
if not is_jax_namespace(xp):
476+
pytest.skip("backend does not support abstract arrays")
477+
import jax
478+
import jax.numpy as jnp
479+
480+
abstract_input = jax.ShapeDtypeStruct(shape=(None, 784), dtype=jnp.float32)
481+
with pytest.raises(TypeError):
482+
_ = one_hot(abstract_input, 3)
483+
473484
def test_basic(self, xp: ModuleType):
474485
actual = one_hot(xp.asarray([0, 1, 2]), 3)
475486
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

0 commit comments

Comments
 (0)