Skip to content

ENH: Free-threading support #330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,18 @@ jobs:
pixi run -e lint pyright

checks:
name: Check ${{ matrix.environment }}
name: Test ${{ matrix.environment }}
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit-and-lint]
strategy:
fail-fast: false
matrix:
environment: [tests-py310, tests-py313, tests-numpy1, tests-backends]
environment:
- tests-py310
- tests-py313
- tests-numpy1
- tests-backends
- tests-nogil
runs-on: [ubuntu-latest]

steps:
Expand All @@ -66,9 +71,16 @@ jobs:
environments: ${{ matrix.environment }}

- name: Test package
# Save some time; also at the moment of writing coverage crashes on python 3.13t
if: ${{ matrix.environment != 'tests-nogil' }}
run: pixi run -e ${{ matrix.environment }} tests-ci

- name: Test free-threading
if: ${{ matrix.environment == 'tests-nogil' }}
run: pixi run -e tests-nogil tests --parallel-threads=4

- name: Upload coverage report
if: ${{ matrix.environment != 'tests-nogil' }}
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d # v5.4.2
with:
token: ${{ secrets.CODECOV_TOKEN }}
702 changes: 663 additions & 39 deletions pixi.lock

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Free Threading :: 3 - Stable",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole library has 1 stateful class, at, which has trivial state. So no point adding tests beyond those of pytest-run-parallel.
Where we are thread-unsafe, in lazy_xp_function, it is clearly documented, so I believe this should qualify for the Stable badge?

"Typing :: Typed",
]
dynamic = ["version"]
Expand All @@ -47,7 +48,6 @@ channels = ["https://prefix.dev/conda-forge"]
platforms = ["linux-64", "osx-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10,<3.14"
array-api-compat = ">=1.12.0,<2"

[tool.pixi.pypi-dependencies]
Expand Down Expand Up @@ -179,6 +179,15 @@ cupy = ">=13.4.1"
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
pytorch = { version = ">=2.7.0", build = "cuda12*" }

[tool.pixi.feature.nogil.dependencies]
python-freethreading = "~=3.13.0"
pytest-run-parallel = ">=0.4.3"
numpy = ">=2.3.0"
# pytorch = "*" # Not available on Python 3.13t yet
dask-core = ">=2025.5.1" # No distributed, tornado, etc.
# sparse = "*" # numba not available on Python 3.13t yet
# jax = "*" # ml_dtypes not available on Python 3.13t yet

[tool.pixi.environments]
default = { features = ["py313"], solve-group = "py313" }
lint = { features = ["py313", "lint"], solve-group = "py313" }
Expand All @@ -197,7 +206,7 @@ tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve
# Ungrouped environments
tests-numpy1 = ["py310", "tests", "numpy1"]
tests-py310 = ["py310", "tests"]

tests-nogil = ["nogil", "tests"]

# pytest

Expand Down
40 changes: 33 additions & 7 deletions src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
from __future__ import annotations

from enum import Enum
from typing import Any

__all__ = ["Backend"]
import numpy as np
import pytest

__all__ = ["NUMPY_VERSION", "Backend"]

NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[:3]) # pyright: ignore[reportUnknownArgumentType]


class Backend(Enum): # numpydoc ignore=PR02
Expand All @@ -30,12 +36,6 @@ class Backend(Enum): # numpydoc ignore=PR02
JAX = "jax.numpy"
JAX_GPU = "jax.numpy:gpu"

def __str__(self) -> str: # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
return (
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
)

@property
def modname(self) -> str: # numpydoc ignore=RT01
"""Module name to be imported."""
Expand All @@ -44,3 +44,29 @@ def modname(self) -> str: # numpydoc ignore=RT01
def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
"""Check if this backend uses the same module as others."""
return any(self.modname == other.modname for other in others)

def pytest_param(self) -> Any:
"""
Backend as a pytest parameter

Returns
-------
pytest.mark.ParameterSet
"""
id_ = (
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
)

marks = []
if self.like(Backend.ARRAY_API_STRICT):
marks.append(
pytest.mark.skipif(
NUMPY_VERSION < (1, 26),
reason="array_api_strict is untested on NumPy <1.26",
)
)
if self.like(Backend.DASK, Backend.JAX):
# Monkey-patched by lazy_xp_function
marks.append(pytest.mark.thread_unsafe)

return pytest.param(self, id=id_, marks=marks) # pyright: ignore[reportUnknownArgumentType]
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def wrapper(
if as_numpy:
import numpy as np

arg = cast(Array, np.asarray(arg)) # noqa: PLW2901
arg = cast(Array, np.asarray(arg)) # pyright: ignore[reportInvalidCast] # noqa: PLW2901
args_list.append(arg)
assert device is not None

Expand Down
73 changes: 63 additions & 10 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import contextlib
import enum
import warnings
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Generator, Iterator, Sequence
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
Expand Down Expand Up @@ -216,8 +216,11 @@ def test_myfunc(xp):


def patch_lazy_xp_functions(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType
) -> None:
request: pytest.FixtureRequest,
monkeypatch: pytest.MonkeyPatch | None = None,
*,
xp: ModuleType,
) -> contextlib.AbstractContextManager[None]:
"""
Test lazy execution of functions tagged with :func:`lazy_xp_function`.

Expand All @@ -233,10 +236,15 @@ def patch_lazy_xp_functions(
This function should be typically called by your library's `xp` fixture that runs
tests on multiple backends::

@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
def xp(request, monkeypatch):
patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
return request.param
@pytest.fixture(params=[
numpy,
array_api_strict,
pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
])
def xp(request):
with patch_lazy_xp_functions(request, xp=request.param):
yield request.param

but it can be otherwise be called by the test itself too.

Expand All @@ -245,18 +253,50 @@ def xp(request, monkeypatch):
request : pytest.FixtureRequest
Pytest fixture, as acquired by the test itself or by one of its fixtures.
monkeypatch : pytest.MonkeyPatch
Pytest fixture, as acquired by the test itself or by one of its fixtures.
Deprecated
xp : array_namespace
Array namespace to be tested.

See Also
--------
lazy_xp_function : Tag a function to be tested on lazy backends.
pytest.FixtureRequest : `request` test function parameter.

Notes
-----
This context manager monkey-patches modules and as such is thread unsafe
on Dask and JAX. If you run your test suite with
`pytest-run-parallel <https://github.com/Quansight-Labs/pytest-run-parallel/>`_,
you should mark these backends with ``@pytest.mark.thread_unsafe``, as shown in
the example above.
"""
mod = cast(ModuleType, request.module)
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]

to_revert: list[tuple[ModuleType, str, object]] = []

def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
"""
Variant of monkeypatch.setattr, which allows monkey-patching only selected
parameters of a test so that pytest-run-parallel can run on the remainder.
"""
assert hasattr(mod, name)
to_revert.append((mod, name, getattr(mod, name)))
setattr(mod, name, func)

if monkeypatch is not None:
warnings.warn(
(
"The `monkeypatch` parameter is deprecated and will be removed in a "
"future version. "
"Use `patch_lazy_xp_function` as a context manager instead."
),
DeprecationWarning,
stacklevel=2,
)
# Enable using patch_lazy_xp_function not as a context manager
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]

def iter_tagged() -> (
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
):
Expand All @@ -279,13 +319,26 @@ def iter_tagged() -> (
elif n is False:
n = 0
wrapped = _dask_wrap(func, n)
monkeypatch.setattr(mod, name, wrapped)
temp_setattr(mod, name, wrapped)

elif is_jax_namespace(xp):
for mod, name, func, tags in iter_tagged():
if tags["jax_jit"]:
wrapped = jax_autojit(func)
monkeypatch.setattr(mod, name, wrapped)
temp_setattr(mod, name, wrapped)

# We can't just decorate patch_lazy_xp_functions with
# @contextlib.contextmanager because it would not work with the
# deprecated monkeypatch when not used as a context manager.
@contextlib.contextmanager
def revert_on_exit() -> Generator[None]:
try:
yield
finally:
for mod, name, orig_func in to_revert:
setattr(mod, name, orig_func)

return revert_on_exit()


class CountingDaskScheduler(SchedulerGetCallable):
Expand Down
38 changes: 17 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
T = TypeVar("T")
P = ParamSpec("P")

NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


@pytest.fixture(params=tuple(Backend))
@pytest.fixture(params=[b.pytest_param() for b in Backend])
def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.
Expand Down Expand Up @@ -112,7 +111,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08

@pytest.fixture
def xp(
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
library: Backend, request: pytest.FixtureRequest
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.
Expand All @@ -125,9 +124,6 @@ def xp(
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
return

if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
pytest.skip("array_api_strict is untested on NumPy <1.26")

xp = pytest.importorskip(library.modname)
# Possibly wrap module with array_api_compat
xp = array_namespace(xp.empty(0))
Expand All @@ -143,16 +139,15 @@ def xp(
yield xp
return

# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
# in the global scope of the module containing the test function.
patch_lazy_xp_functions(request, monkeypatch, xp=xp)

if library.like(Backend.JAX):
_setup_jax(library)
elif library.like(Backend.TORCH):
_setup_torch(library)

yield xp
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
# in the global scope of the module containing the test function.
with patch_lazy_xp_functions(request, xp=xp):
yield xp


def _setup_jax(library: Backend) -> None:
Expand Down Expand Up @@ -189,26 +184,27 @@ def _setup_torch(library: Backend) -> None:
torch.set_default_device("cpu")


@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
# Can select the test with `pytest -k dask`
@pytest.fixture(params=[Backend.DASK.pytest_param()])
def da(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType: # numpydoc ignore=PR01,RT01
request: pytest.FixtureRequest,
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT01
"""Variant of the `xp` fixture that only yields dask.array."""
xp = pytest.importorskip("dask.array")
xp = array_namespace(xp.empty(0))
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
return xp
with patch_lazy_xp_functions(request, xp=xp):
yield xp


@pytest.fixture(params=[Backend.JAX, Backend.JAX_GPU])
@pytest.fixture(params=[Backend.JAX.pytest_param(), Backend.JAX_GPU.pytest_param()])
def jnp(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType: # numpydoc ignore=PR01,RT01
request: pytest.FixtureRequest,
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT01
"""Variant of the `xp` fixture that only yields jax.numpy."""
xp = pytest.importorskip("jax.numpy")
_setup_jax(request.param)
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
return xp
with patch_lazy_xp_functions(request, xp=xp):
yield xp


@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_copy_invalid():


def test_xp():
a = cast(Array, np.asarray([1, 2, 3]))
a = cast(Array, np.asarray([1, 2, 3])) # pyright: ignore[reportInvalidCast]
_ = at(a, 0).set(4, xp=np)
_ = at(a, 0).add(4, xp=np)
_ = at(a, 0).subtract(4, xp=np)
Expand Down
4 changes: 1 addition & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@
setdiff1d,
sinc,
)
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

from .conftest import NUMPY_VERSION

lazy_xp_function(apply_where)
lazy_xp_function(atleast_nd)
lazy_xp_function(cov)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from .conftest import np_compat

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.12)
from typing_extensions import override
else:
Expand Down
Loading