From c03daa36c09d51162d240b77e223a49cc8a6076e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:36:11 +0200 Subject: [PATCH 1/3] CI: install jax/sparse/torch in more jobs Also, `ndonnx` has wheels for all python versions now; And we do not bother with jax or dask numpy < 1. --- .github/workflows/tests.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 81a05b3f..c995b370 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,20 +32,20 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.numpy-version }}" == "dev" ]; then python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -m pip install dask[array] jax[cpu] sparse ndonnx elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then python -m pip install 'numpy==1.22.*' elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then python -m pip install 'numpy==1.26.*' else - # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack - python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse - python -m pip install torch --index-url https://download.pytorch.org/whl/cpu - if [ "${{ matrix.python-version }}" != "3.13" ]; then - # onnx wheels are not available on Python 3.13 at the moment of writing - python -m pip install ndonnx - fi + python -m pip install numpy + python -m pip install dask[array] jax[cpu] sparse ndonnx fi - name: Dump pip environment From a8e19835092335ab8e1846f1e3dda335d8eb4c4a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:38:31 +0200 Subject: [PATCH 2/3] TST: xfail test_device_to_device with numpy < 2 It assumes that asarray has the copy kwarg, and this is not true in NumPy < 2. --- tests/test_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 54b5ed69..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -195,6 +195,9 @@ def test_device_to_device(library, request): xfail(request, reason="Stub raises ValueError") if library == "sparse": xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") xp = import_(library, wrapper=True) devices = xp.__array_namespace_info__().devices() From 8e3ab3e7c5c6794f66196ec435d2f6bdd1492404 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:39:28 +0200 Subject: [PATCH 3/3] MAINT: filter out some warning noise --- tests/test_array_namespace.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index cdb80007..2fbb0339 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat): if library == "ndonnx" and api_version in ("2021.12", "2022.12"): pytest.skip("Unsupported API version") - namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: @@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat): if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ @@ -97,7 +102,9 @@ def test_api_version_torch(): torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2023.12") == torch_ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning