Skip to content

Commit bf986cf

Browse files
committed
Run tests on ndonnx
1 parent b87e0aa commit bf986cf

File tree

5 files changed

+45
-24
lines changed

5 files changed

+45
-24
lines changed

array_api_compat/common/_helpers.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -627,17 +627,13 @@ def device(x: Array, /) -> Device:
627627
to_device : Move array data to a different device.
628628
629629
"""
630-
if is_numpy_array(x):
630+
if is_numpy_array(x) or is_ndonnx_array(x):
631631
return "cpu"
632632
elif is_dask_array(x):
633633
# Peek at the metadata of the jax array to determine type
634-
try:
635-
import numpy as np
636-
if isinstance(x._meta, np.ndarray):
637-
# Must be on CPU since backed by numpy
638-
return "cpu"
639-
except ImportError:
640-
pass
634+
if is_numpy_array(x._meta):
635+
# Must be on CPU since backed by numpy
636+
return "cpu"
641637
return _DASK_DEVICE
642638
elif is_jax_array(x):
643639
# JAX has .device() as a method, but it is being deprecated so that it
@@ -758,7 +754,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
758754
device : Hardware device the array data resides on.
759755
760756
"""
761-
if is_numpy_array(x):
757+
if is_numpy_array(x) or is_ndonnx_array(x):
762758
if stream is not None:
763759
raise ValueError("The stream argument to to_device() is not supported")
764760
if device == 'cpu':
@@ -780,7 +776,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
780776
if not hasattr(x, "__array_namespace__"):
781777
# In JAX v0.4.31 and older, this import adds to_device method to x.
782778
import jax.experimental.array_api # noqa: F401
783-
return x.to_device(device, stream=stream)
784779
elif is_pydata_sparse_array(x) and device == _device(x):
785780
# Perform trivial check to return the same array if
786781
# device is same instead of err-ing.

docs/supported-array-libraries.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,14 @@ The minimum supported Dask version is 2023.12.0.
137137
## [Sparse](https://sparse.pydata.org/en/stable/)
138138

139139
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
140+
141+
(ndonnx-support)=
142+
## [ndonnx](https://github.com/quantco/ndonnx)
143+
144+
Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.
145+
146+
(array-api-strict-support)=
147+
## [array-api-strict](https://data-apis.org/array-api-strict/)
148+
149+
array-api-strict exists only to test support for the Array API, so obviously it
150+
does not need any wrappers.

tests/_helpers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
from importlib import import_module
2-
import sys
32

43
import pytest
54

65
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
76
all_libraries = wrapped_libraries + [
8-
"array_api_strict", "jax.numpy", "sparse"
7+
"array_api_strict", "jax.numpy", "ndonnx", "sparse"
98
]
109

11-
# `sparse` added array API support as of Python 3.10.
12-
if sys.version_info >= (3, 10):
13-
all_libraries.append('sparse')
14-
1510
def import_(library, wrapper=False):
16-
if library == 'cupy':
11+
if library in ('cupy', 'ndonnx'):
1712
pytest.importorskip(library)
1813
if wrapper:
1914
if 'jax' in library:

tests/test_array_namespace.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
@pytest.mark.parametrize("use_compat", [True, False, None])
1616
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
17-
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17+
@pytest.mark.parametrize("library", all_libraries)
1818
def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
22+
if use_compat and library not in wrapped_libraries:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25+
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
26+
pytest.skip("Unsupported API version")
27+
2528
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2629

2730
if use_compat is False or use_compat is None and library not in wrapped_libraries:

tests/test_common.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from array_api_compat import ( # noqa: F401
22
is_numpy_array, is_cupy_array, is_torch_array,
33
is_dask_array, is_jax_array, is_pydata_sparse_array,
4+
is_ndonnx_array,
45
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
56
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
6-
is_array_api_strict_namespace,
7+
is_array_api_strict_namespace, is_ndonnx_namespace,
78
)
89

910
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
@@ -22,6 +23,7 @@
2223
'dask.array': 'is_dask_array',
2324
'jax.numpy': 'is_jax_array',
2425
'sparse': 'is_pydata_sparse_array',
26+
'ndonnx': 'is_ndonnx_array',
2527
}
2628

2729
is_namespace_functions = {
@@ -32,6 +34,7 @@
3234
'jax.numpy': 'is_jax_namespace',
3335
'sparse': 'is_pydata_sparse_namespace',
3436
'array_api_strict': 'is_array_api_strict_namespace',
37+
'ndonnx': 'is_ndonnx_namespace',
3538
}
3639

3740

@@ -135,26 +138,40 @@ def test_to_device_host(library):
135138
@pytest.mark.parametrize("target_library", is_array_functions.keys())
136139
@pytest.mark.parametrize("source_library", is_array_functions.keys())
137140
def test_asarray_cross_library(source_library, target_library, request):
138-
if source_library == "dask.array" and target_library == "torch":
141+
def _xfail(reason: str) -> None:
139142
# Allow rest of test to execute instead of immediately xfailing
140143
# xref https://github.com/pandas-dev/pandas/issues/38902
144+
request.node.add_marker(pytest.mark.xfail(reason=reason))
141145

146+
if source_library == "dask.array" and target_library == "torch":
142147
# TODO: remove xfail once
143148
# https://github.com/dask/dask/issues/8260 is resolved
144-
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
145-
if source_library == "cupy" and target_library != "cupy":
149+
_xfail(reason="Bug in dask raising error on conversion")
150+
elif (
151+
source_library == "ndonnx"
152+
and target_library not in ("array_api_strict", "ndonnx", "numpy")
153+
):
154+
_xfail(reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
155+
elif source_library == "ndonnx" and target_library == "numpy":
156+
_xfail(reason="produces numpy array of ndonnx scalar arrays")
157+
elif source_library == "jax.numpy" and target_library == "torch":
158+
_xfail(reason="casts int to float")
159+
elif source_library == "cupy" and target_library != "cupy":
146160
# cupy explicitly disallows implicit conversions to CPU
147161
pytest.skip(reason="cupy does not support implicit conversion to CPU")
148162
elif source_library == "sparse" and target_library != "sparse":
149163
pytest.skip(reason="`sparse` does not allow implicit densification")
164+
150165
src_lib = import_(source_library, wrapper=True)
151166
tgt_lib = import_(target_library, wrapper=True)
152167
is_tgt_type = globals()[is_array_functions[target_library]]
153168

154-
a = src_lib.asarray([1, 2, 3])
169+
a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
155170
b = tgt_lib.asarray(a)
156171

157172
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
173+
assert b.dtype == tgt_lib.int32
174+
158175

159176
@pytest.mark.parametrize("library", wrapped_libraries)
160177
def test_asarray_copy(library):

0 commit comments

Comments
 (0)