Skip to content

Commit 57a38f9

Browse files
committed
TST/BUG: run all tests on all backends; fix backend-specific bugs
1 parent 1708482 commit 57a38f9

File tree

15 files changed

+458
-201
lines changed

15 files changed

+458
-201
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ repos:
4444
- repo: https://github.com/astral-sh/ruff-pre-commit
4545
rev: "v0.8.2"
4646
hooks:
47+
- id: ruff-format
4748
- id: ruff
4849
args: ["--fix", "--show-fixes"]
49-
- id: ruff-format
5050

5151
- repo: https://github.com/codespell-project/codespell
5252
rev: "v2.3.0"

docs/api-reference.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,15 @@
1515
setdiff1d
1616
sinc
1717
```
18+
19+
## Test tools
20+
21+
```{eval-rst}
22+
.. currentmodule:: array_api_extra.testing
23+
.. autosummary::
24+
:nosignatures:
25+
:toctree: generated
26+
27+
xp_assert_equal
28+
xp_assert_close
29+
```

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ xfail_strict = true
180180
filterwarnings = ["error"]
181181
log_cli_level = "INFO"
182182
testpaths = ["tests"]
183-
183+
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
184184

185185
# Coverage
186186

src/array_api_extra/_funcs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def create_diagonal(
214214
raise ValueError(err_msg)
215215
n = x.shape[0] + abs(offset)
216216
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
217-
i = offset if offset >= 0 else abs(offset) * n
218-
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
217+
218+
start = offset if offset >= 0 else abs(offset) * n
219+
stop = min(n * (n - offset), diag.shape[0])
220+
step = n + 1
221+
diag = at(diag)[start:stop:step].set(x)
222+
219223
return xp.reshape(diag, (n, n))
220224

221225

@@ -407,9 +411,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
407411
result = xp.multiply(a_arr, b_arr)
408412

409413
# Reshape back and return
410-
a_shape = xp.asarray(a_shape)
411-
b_shape = xp.asarray(b_shape)
412-
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
414+
res_shape = tuple(a_s * b_s for a_s, b_s in zip(a_shape, b_shape, strict=True))
415+
return xp.reshape(result, res_shape)
413416

414417

415418
def setdiff1d(
@@ -632,8 +635,7 @@ def pad(
632635
dtype=x.dtype,
633636
device=_compat.device(x),
634637
)
635-
padded[tuple(slices)] = x
636-
return padded
638+
return at(padded, tuple(slices)).set(x)
637639

638640

639641
class _AtOp(Enum):

src/array_api_extra/_lib/_compat.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,35 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace,
88
device,
9+
is_cupy_namespace,
910
is_jax_array,
11+
is_jax_namespace,
12+
is_pydata_sparse_namespace,
13+
is_torch_namespace,
1014
is_writeable_array,
15+
size,
1116
)
1217
except ImportError:
1318
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1419
array_namespace,
1520
device,
21+
is_cupy_namespace,
1622
is_jax_array,
23+
is_jax_namespace,
24+
is_pydata_sparse_namespace,
25+
is_torch_namespace,
1726
is_writeable_array,
27+
size,
1828
)
1929

2030
__all__ = [
2131
"array_namespace",
2232
"device",
33+
"is_cupy_namespace",
2334
"is_jax_array",
35+
"is_jax_namespace",
36+
"is_pydata_sparse_namespace",
37+
"is_torch_namespace",
2438
"is_writeable_array",
39+
"size",
2540
]

src/array_api_extra/_lib/_compat.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,10 @@ def array_namespace(
1818
use_compat: bool | None = None,
1919
) -> ArrayModule: ...
2020
def device(x: Array, /) -> Device: ...
21+
def is_cupy_namespace(x: object, /) -> bool: ...
2122
def is_jax_array(x: object, /) -> bool: ...
23+
def is_jax_namespace(x: object, /) -> bool: ...
24+
def is_pydata_sparse_namespace(x: object, /) -> bool: ...
25+
def is_torch_namespace(x: object, /) -> bool: ...
2226
def is_writeable_array(x: object, /) -> bool: ...
27+
def size(x: Array, /) -> int | None: ...

src/array_api_extra/_lib/_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def in1d(
5454
order = xp.argsort(ar, stable=True)
5555
reverse_order = xp.argsort(order, stable=True)
5656
sar = xp.take(ar, order, axis=0)
57-
if sar.size >= 1:
57+
ar_size = _compat.size(sar)
58+
assert ar_size is not None, "xp.unique*() on lazy backends raises"
59+
if ar_size >= 1:
5860
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
5961
else:
6062
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])

src/array_api_extra/testing.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Testing utilities."""
2+
3+
from ._lib._compat import (
4+
array_namespace,
5+
is_cupy_namespace,
6+
is_pydata_sparse_namespace,
7+
is_torch_namespace,
8+
)
9+
from ._lib._typing import Array
10+
11+
__all__ = ["xp_assert_close", "xp_assert_equal"]
12+
13+
14+
def _check_shape_dtype(actual: Array, desired: Array) -> None:
15+
"""
16+
Assert that shape and dtype of the two arrays match.
17+
18+
Parameters
19+
----------
20+
actual : Array
21+
The array produced by the tested function.
22+
desired : Array
23+
The expected array (typically hardcoded).
24+
"""
25+
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
26+
assert actual.shape == desired.shape, msg
27+
28+
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}".format(
29+
actual.dtype, desired.dtype
30+
)
31+
assert actual.dtype == desired.dtype, msg
32+
33+
34+
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
35+
"""
36+
Array-API compatible version of `np.testing.assert_array_equal`.
37+
38+
Parameters
39+
----------
40+
actual : Array
41+
The array produced by the tested function.
42+
desired : Array
43+
The expected array (typically hardcoded).
44+
err_msg : str, optional
45+
Error message to display on failure.
46+
"""
47+
xp = array_namespace(actual, desired)
48+
_check_shape_dtype(actual, desired)
49+
50+
if is_cupy_namespace(xp):
51+
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
52+
elif is_torch_namespace(xp):
53+
# PyTorch recommends using `rtol=0, atol=0` like this
54+
# to test for exact equality
55+
xp.testing.assert_close(
56+
actual,
57+
desired,
58+
rtol=0,
59+
atol=0,
60+
equal_nan=True,
61+
check_dtype=False,
62+
msg=err_msg or None,
63+
)
64+
else:
65+
import numpy as np # pylint: disable=import-outside-toplevel
66+
67+
if is_pydata_sparse_namespace(xp):
68+
actual = actual.todense()
69+
desired = desired.todense()
70+
71+
# JAX uses `np.testing`
72+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
73+
74+
75+
def xp_assert_close(
76+
actual: Array,
77+
desired: Array,
78+
*,
79+
rtol: float | None = None,
80+
atol: float = 0,
81+
err_msg: str = "",
82+
) -> None:
83+
"""
84+
Array-API compatible version of `np.testing.assert_allclose`.
85+
86+
Parameters
87+
----------
88+
actual : Array
89+
The array produced by the tested function.
90+
desired : Array
91+
The expected array (typically hardcoded).
92+
rtol : float, optional
93+
Relative tolerance. Default: dtype-dependent.
94+
atol : float, optional
95+
Absolute tolerance. Default: 0.
96+
err_msg : str, optional
97+
Error message to display on failure.
98+
"""
99+
xp = array_namespace(actual, desired)
100+
_check_shape_dtype(actual, desired)
101+
102+
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
103+
if rtol is None and floating:
104+
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
105+
# roughly half way between sqrt(eps) and the default for
106+
# `numpy.testing.assert_allclose`, 1e-7
107+
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
108+
elif rtol is None:
109+
rtol = 1e-7
110+
111+
if is_cupy_namespace(xp):
112+
xp.testing.assert_allclose(
113+
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
114+
)
115+
elif is_torch_namespace(xp):
116+
xp.testing.assert_close(
117+
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
118+
)
119+
else:
120+
import numpy as np # pylint: disable=import-outside-toplevel
121+
122+
if is_pydata_sparse_namespace(xp):
123+
actual = actual.to_dense()
124+
desired = desired.to_dense()
125+
126+
# JAX uses `np.testing`
127+
assert isinstance(rtol, float)
128+
np.testing.assert_allclose(
129+
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
130+
)

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Needed to import .conftest from the test modules."""

tests/conftest.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Pytest fixtures."""
2+
3+
from enum import Enum
4+
from typing import cast
5+
6+
import pytest
7+
8+
from array_api_extra._lib._compat import array_namespace
9+
from array_api_extra._lib._compat import device as get_device
10+
from array_api_extra._lib._typing import Device, ModuleType
11+
12+
13+
class Library(Enum):
14+
"""All array libraries explicitly tested by array-api-extra."""
15+
16+
ARRAY_API_STRICT = "array_api_strict"
17+
NUMPY = "numpy"
18+
NUMPY_READONLY = "numpy_readonly"
19+
CUPY = "cupy"
20+
TORCH = "torch"
21+
DASK_ARRAY = "dask.array"
22+
SPARSE = "sparse"
23+
JAX_NUMPY = "jax.numpy"
24+
25+
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
26+
"""Pretty-print parameterized test names."""
27+
return self.value
28+
29+
30+
@pytest.fixture(params=tuple(Library))
31+
def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,RT03
32+
"""
33+
Parameterized fixture that iterates on all libraries.
34+
35+
Returns
36+
-------
37+
The current Library enum.
38+
"""
39+
elem = cast(Library, request.param)
40+
41+
for marker in request.node.iter_markers("skip_xp_backend"):
42+
skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
43+
if not isinstance(skip_library, Library):
44+
msg = "argument of skip_xp_backend must be a Library enum"
45+
raise TypeError(msg)
46+
if skip_library == elem:
47+
reason = cast(str, marker.kwargs.get("reason", "skip_xp_backend"))
48+
pytest.skip(reason=reason)
49+
50+
return elem
51+
52+
53+
@pytest.fixture
54+
def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
55+
"""
56+
Parameterized fixture that iterates on all libraries.
57+
58+
Returns
59+
-------
60+
The current array namespace.
61+
"""
62+
name = "numpy" if library == Library.NUMPY_READONLY else library.value
63+
xp = pytest.importorskip(name)
64+
if library == Library.JAX_NUMPY:
65+
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
66+
67+
jax.config.update("jax_enable_x64", True)
68+
69+
# Possibly wrap module with array_api_compat
70+
return array_namespace(xp.empty(0))
71+
72+
73+
@pytest.fixture
74+
def device(
75+
library: Library, xp: ModuleType
76+
) -> Device: # numpydoc ignore=PR01,RT01,RT03
77+
"""
78+
Return a valid device for the backend.
79+
80+
Where possible, return a device that is not the default one.
81+
"""
82+
if library == Library.ARRAY_API_STRICT:
83+
d = xp.Device("device1")
84+
assert get_device(xp.empty(0)) != d
85+
return d
86+
return get_device(xp.empty(0))

0 commit comments

Comments
 (0)