Skip to content

Commit f9caeaf

Browse files
committed
TST: Free-threaded tests
1 parent a127376 commit f9caeaf

File tree

12 files changed

+827
-90
lines changed

12 files changed

+827
-90
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,18 @@ jobs:
4545
pixi run -e lint pyright
4646
4747
checks:
48-
name: Check ${{ matrix.environment }}
48+
name: Test ${{ matrix.environment }}
4949
runs-on: ${{ matrix.runs-on }}
5050
needs: [pre-commit-and-lint]
5151
strategy:
5252
fail-fast: false
5353
matrix:
54-
environment: [tests-py310, tests-py313, tests-numpy1, tests-backends]
54+
environment:
55+
- tests-py310
56+
- tests-py313
57+
- tests-numpy1
58+
- tests-backends
59+
- tests-nogil
5560
runs-on: [ubuntu-latest]
5661

5762
steps:
@@ -66,9 +71,16 @@ jobs:
6671
environments: ${{ matrix.environment }}
6772

6873
- name: Test package
74+
# Save some time; also at the moment of writing coverage crashes on python 3.13t
75+
if: ${{ matrix.environment != 'tests-nogil' }}
6976
run: pixi run -e ${{ matrix.environment }} tests-ci
7077

78+
- name: Test free-threading
79+
if: ${{ matrix.environment == 'tests-nogil' }}
80+
run: pixi run -e tests-nogil tests --parallel-threads=4
81+
7182
- name: Upload coverage report
83+
if: ${{ matrix.environment != 'tests-nogil' }}
7284
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d # v5.4.2
7385
with:
7486
token: ${{ secrets.CODECOV_TOKEN }}

pixi.lock

Lines changed: 663 additions & 39 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ channels = ["https://prefix.dev/conda-forge"]
4747
platforms = ["linux-64", "osx-64", "osx-arm64", "win-64"]
4848

4949
[tool.pixi.dependencies]
50-
python = ">=3.10,<3.14"
5150
array-api-compat = ">=1.12.0,<2"
5251

5352
[tool.pixi.pypi-dependencies]
@@ -179,6 +178,15 @@ cupy = ">=13.4.1"
179178
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
180179
pytorch = { version = ">=2.7.0", build = "cuda12*" }
181180

181+
[tool.pixi.feature.nogil.dependencies]
182+
python-freethreading = "~=3.13.0"
183+
pytest-run-parallel = ">=0.4.3"
184+
numpy = ">=2.3.0"
185+
# pytorch = "*" # Not available on Python 3.13t yet
186+
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
187+
# sparse = "*" # numba not available on Python 3.13t yet
188+
# jax = "*" # ml_dtypes not available on Python 3.13t yet
189+
182190
[tool.pixi.environments]
183191
default = { features = ["py313"], solve-group = "py313" }
184192
lint = { features = ["py313", "lint"], solve-group = "py313" }
@@ -197,7 +205,7 @@ tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve
197205
# Ungrouped environments
198206
tests-numpy1 = ["py310", "tests", "numpy1"]
199207
tests-py310 = ["py310", "tests"]
200-
208+
tests-nogil = ["nogil", "tests"]
201209

202210
# pytest
203211

src/array_api_extra/_lib/_backends.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
from __future__ import annotations
44

55
from enum import Enum
6+
from typing import Any
67

7-
__all__ = ["Backend"]
8+
import numpy as np
9+
import pytest
10+
11+
__all__ = ["NUMPY_VERSION", "Backend"]
12+
13+
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[:3]) # pyright: ignore[reportUnknownArgumentType]
814

915

1016
class Backend(Enum): # numpydoc ignore=PR02
@@ -30,12 +36,6 @@ class Backend(Enum): # numpydoc ignore=PR02
3036
JAX = "jax.numpy"
3137
JAX_GPU = "jax.numpy:gpu"
3238

33-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
34-
"""Pretty-print parameterized test names."""
35-
return (
36-
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
37-
)
38-
3939
@property
4040
def modname(self) -> str: # numpydoc ignore=RT01
4141
"""Module name to be imported."""
@@ -44,3 +44,29 @@ def modname(self) -> str: # numpydoc ignore=RT01
4444
def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
4545
"""Check if this backend uses the same module as others."""
4646
return any(self.modname == other.modname for other in others)
47+
48+
def pytest_param(self) -> Any: # type: ignore[explicit-any]
49+
"""
50+
Backend as a pytest parameter
51+
52+
Returns
53+
-------
54+
pytest.mark.ParameterSet
55+
"""
56+
id_ = (
57+
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
58+
)
59+
60+
marks = []
61+
if self.like(Backend.ARRAY_API_STRICT):
62+
marks.append(
63+
pytest.mark.skipif(
64+
NUMPY_VERSION < (1, 26),
65+
reason="array_api_strict is untested on NumPy <1.26",
66+
)
67+
)
68+
if self.like(Backend.DASK, Backend.JAX):
69+
# Monkey-patched by lazy_xp_function
70+
marks.append(pytest.mark.thread_unsafe)
71+
72+
return pytest.param(self, id=id_, marks=marks) # pyright: ignore[reportUnknownArgumentType]

src/array_api_extra/_lib/_lazy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
343343
if as_numpy:
344344
import numpy as np
345345

346-
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901
346+
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast] # noqa: PLW2901
347347
args_list.append(arg)
348348
assert device is not None
349349

src/array_api_extra/testing.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import contextlib
1010
import enum
1111
import warnings
12-
from collections.abc import Callable, Iterator, Sequence
12+
from collections.abc import Callable, Generator, Iterator, Sequence
1313
from functools import wraps
1414
from types import ModuleType
1515
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
@@ -216,8 +216,11 @@ def test_myfunc(xp):
216216

217217

218218
def patch_lazy_xp_functions(
219-
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType
220-
) -> None:
219+
request: pytest.FixtureRequest,
220+
monkeypatch: pytest.MonkeyPatch | None = None,
221+
*,
222+
xp: ModuleType,
223+
) -> contextlib.AbstractContextManager[None]:
221224
"""
222225
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
223226
@@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
233236
This function should be typically called by your library's `xp` fixture that runs
234237
tests on multiple backends::
235238
236-
@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
237-
def xp(request, monkeypatch):
238-
patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
239-
return request.param
239+
@pytest.fixture(params=[
240+
numpy,
241+
array_api_strict,
242+
pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
243+
pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
244+
])
245+
def xp(request):
246+
with patch_lazy_xp_functions(request, xp=request.param):
247+
yield request.param
240248
241249
but it can be otherwise be called by the test itself too.
242250
@@ -245,18 +253,50 @@ def xp(request, monkeypatch):
245253
request : pytest.FixtureRequest
246254
Pytest fixture, as acquired by the test itself or by one of its fixtures.
247255
monkeypatch : pytest.MonkeyPatch
248-
Pytest fixture, as acquired by the test itself or by one of its fixtures.
256+
Deprecated
249257
xp : array_namespace
250258
Array namespace to be tested.
251259
252260
See Also
253261
--------
254262
lazy_xp_function : Tag a function to be tested on lazy backends.
255263
pytest.FixtureRequest : `request` test function parameter.
264+
265+
Notes
266+
-----
267+
This context manager monkey-patches modules and as such is thread unsafe
268+
on Dask and JAX. If you run your test suite with
269+
`pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_,
270+
you should mark these backends with ``@pytest.mark.thread_unsafe``, as shown in
271+
the example above.
256272
"""
257273
mod = cast(ModuleType, request.module)
258274
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]
259275

276+
to_revert: list[tuple[ModuleType, str, object]] = []
277+
278+
def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
279+
"""
280+
Variant of monkeypatch.setattr, which allows monkey-patching only selected
281+
parameters of a test so that pytest-run-parallel can run on the remainder.
282+
"""
283+
assert hasattr(mod, name)
284+
to_revert.append((mod, name, getattr(mod, name)))
285+
setattr(mod, name, func)
286+
287+
if monkeypatch is not None:
288+
warnings.warn(
289+
(
290+
"The `monkeypatch` parameter is deprecated and will be removed in a "
291+
"future version. "
292+
"Use `patch_lazy_xp_function` as a context manager instead."
293+
),
294+
DeprecationWarning,
295+
stacklevel=2,
296+
)
297+
# Enable using patch_lazy_xp_function not as a context manager
298+
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
299+
260300
def iter_tagged() -> ( # type: ignore[explicit-any]
261301
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
262302
):
@@ -279,13 +319,26 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
279319
elif n is False:
280320
n = 0
281321
wrapped = _dask_wrap(func, n)
282-
monkeypatch.setattr(mod, name, wrapped)
322+
temp_setattr(mod, name, wrapped)
283323

284324
elif is_jax_namespace(xp):
285325
for mod, name, func, tags in iter_tagged():
286326
if tags["jax_jit"]:
287327
wrapped = jax_autojit(func)
288-
monkeypatch.setattr(mod, name, wrapped)
328+
temp_setattr(mod, name, wrapped)
329+
330+
# We can't just decorate patch_lazy_xp_functions with
331+
# @contextlib.contextmanager because it would not work with the
332+
# deprecated monkeypatch when not used as a context manager.
333+
@contextlib.contextmanager
334+
def revert_on_exit() -> Generator[None]:
335+
try:
336+
yield
337+
finally:
338+
for mod, name, orig_func in to_revert:
339+
setattr(mod, name, orig_func)
340+
341+
return revert_on_exit()
289342

290343

291344
class CountingDaskScheduler(SchedulerGetCallable):

tests/conftest.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
T = TypeVar("T")
1919
P = ParamSpec("P")
2020

21-
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
2221
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2322

2423

25-
@pytest.fixture(params=tuple(Backend))
24+
@pytest.fixture(params=[b.pytest_param() for b in Backend])
2625
def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,RT03
2726
"""
2827
Parameterized fixture that iterates on all libraries.
@@ -112,7 +111,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
112111

113112
@pytest.fixture
114113
def xp(
115-
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
114+
library: Backend, request: pytest.FixtureRequest
116115
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
117116
"""
118117
Parameterized fixture that iterates on all libraries.
@@ -125,9 +124,6 @@ def xp(
125124
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
126125
return
127126

128-
if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
129-
pytest.skip("array_api_strict is untested on NumPy <1.26")
130-
131127
xp = pytest.importorskip(library.modname)
132128
# Possibly wrap module with array_api_compat
133129
xp = array_namespace(xp.empty(0))
@@ -143,16 +139,15 @@ def xp(
143139
yield xp
144140
return
145141

146-
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
147-
# in the global scope of the module containing the test function.
148-
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
149-
150142
if library.like(Backend.JAX):
151143
_setup_jax(library)
152144
elif library.like(Backend.TORCH):
153145
_setup_torch(library)
154146

155-
yield xp
147+
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
148+
# in the global scope of the module containing the test function.
149+
with patch_lazy_xp_functions(request, xp=xp):
150+
yield xp
156151

157152

158153
def _setup_jax(library: Backend) -> None:
@@ -189,26 +184,27 @@ def _setup_torch(library: Backend) -> None:
189184
torch.set_default_device("cpu")
190185

191186

192-
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
187+
# Can select the test with `pytest -k dask`
188+
@pytest.fixture(params=[Backend.DASK.pytest_param()])
193189
def da(
194-
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
195-
) -> ModuleType: # numpydoc ignore=PR01,RT01
190+
request: pytest.FixtureRequest,
191+
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT01
196192
"""Variant of the `xp` fixture that only yields dask.array."""
197193
xp = pytest.importorskip("dask.array")
198194
xp = array_namespace(xp.empty(0))
199-
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
200-
return xp
195+
with patch_lazy_xp_functions(request, xp=xp):
196+
yield xp
201197

202198

203-
@pytest.fixture(params=[Backend.JAX, Backend.JAX_GPU])
199+
@pytest.fixture(params=[Backend.JAX.pytest_param(), Backend.JAX_GPU.pytest_param()])
204200
def jnp(
205-
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
206-
) -> ModuleType: # numpydoc ignore=PR01,RT01
201+
request: pytest.FixtureRequest,
202+
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT01
207203
"""Variant of the `xp` fixture that only yields jax.numpy."""
208204
xp = pytest.importorskip("jax.numpy")
209205
_setup_jax(request.param)
210-
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
211-
return xp
206+
with patch_lazy_xp_functions(request, xp=xp):
207+
yield xp
212208

213209

214210
@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])

tests/test_at.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_copy_invalid():
178178

179179

180180
def test_xp():
181-
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast]
181+
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast]
182182
_ = at(a, 0).set(4, xp=np)
183183
_ = at(a, 0).add(4, xp=np)
184184
_ = at(a, 0).subtract(4, xp=np)

tests/test_funcs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@
2727
setdiff1d,
2828
sinc,
2929
)
30-
from array_api_extra._lib._backends import Backend
30+
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3131
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
3232
from array_api_extra._lib._utils._compat import device as get_device
3333
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
3434
from array_api_extra._lib._utils._typing import Array, Device
3535
from array_api_extra.testing import lazy_xp_function
3636

37-
from .conftest import NUMPY_VERSION
38-
3937
# some xp backends are untyped
4038
# mypy: disable-error-code=no-untyped-def
4139

tests/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from .conftest import np_compat
2626

27-
if TYPE_CHECKING:
27+
if TYPE_CHECKING: # pragma: no cover
2828
# TODO import from typing (requires Python >=3.12)
2929
from typing_extensions import override
3030
else:

0 commit comments

Comments
 (0)