Skip to content

Commit b87e0aa

Browse files
committed
Run more tests on array-api-strict and sparse
1 parent beac55b commit b87e0aa

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

tests/_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55

66
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
7-
all_libraries = wrapped_libraries + ["jax.numpy"]
7+
all_libraries = wrapped_libraries + [
8+
"array_api_strict", "jax.numpy", "sparse"
9+
]
810

911
# `sparse` added array API support as of Python 3.10.
1012
if sys.version_info >= (3, 10):
@@ -20,9 +22,7 @@ def import_(library, wrapper=False):
2022
jax_numpy = import_module("jax.numpy")
2123
if not hasattr(jax_numpy, "__array_api_version__"):
2224
library = 'jax.experimental.array_api'
23-
elif library.startswith('sparse'):
24-
library = 'sparse'
25-
else:
25+
elif library in wrapped_libraries:
2626
library = 'array_api_compat.' + library
2727

2828
return import_module(library)

tests/test_all.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
2020
def test_all(library):
21-
import_(library, wrapper=True)
21+
if library == "common":
22+
import array_api_compat.common # noqa: F401
23+
else:
24+
import_(library, wrapper=True)
2225

2326
for mod_name in sys.modules:
2427
if not mod_name.startswith('array_api_compat.' + library):

tests/test_common.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
is_dask_array, is_jax_array, is_pydata_sparse_array,
44
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
6+
is_array_api_strict_namespace,
67
)
78

89
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
@@ -30,6 +31,7 @@
3031
'dask.array': 'is_dask_namespace',
3132
'jax.numpy': 'is_jax_namespace',
3233
'sparse': 'is_pydata_sparse_namespace',
34+
'array_api_strict': 'is_array_api_strict_namespace',
3335
}
3436

3537

@@ -71,7 +73,12 @@ def test_xp_is_array_generics(library):
7173
is_func = globals()[func]
7274
if is_func(x0):
7375
matches.append(library2)
74-
assert matches in ([library], ["numpy"])
76+
77+
if library == "array_api_strict":
78+
# There is no is_array_api_strict_array() function
79+
assert matches == []
80+
else:
81+
assert matches in ([library], ["numpy"])
7582

7683

7784
@pytest.mark.parametrize("library", all_libraries)

0 commit comments

Comments
 (0)