diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8b3f4e64..e04f7447 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: - name: Run Tests run: | if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask") + PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") fi pytest -v "${PYTEST_EXTRA[@]}" diff --git a/README.md b/README.md index 8f567606..7e96c4df 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array -libraries, or if you encounter any issues, please [open an +NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support +for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). See the documentation for more details https://data-apis.org/array-api-compat/ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 982b284a..79354487 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -50,6 +50,7 @@ def is_numpy_array(x): is_torch_array is_dask_array is_jax_array + is_pydata_sparse """ # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: @@ -79,6 +80,7 @@ def is_cupy_array(x): is_torch_array is_dask_array is_jax_array + is_pydata_sparse """ # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: @@ -105,6 +107,7 @@ def is_torch_array(x): is_cupy_array is_dask_array is_jax_array + is_pydata_sparse """ # Avoid importing torch if it isn't already if 'torch' not in sys.modules: @@ -131,6 +134,7 @@ def is_dask_array(x): is_cupy_array is_torch_array is_jax_array + is_pydata_sparse """ # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: @@ -157,6 +161,7 @@ def is_jax_array(x): is_cupy_array is_torch_array is_dask_array + is_pydata_sparse """ # Avoid importing jax if it isn't already if 'jax' not in sys.modules: @@ -166,6 +171,35 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + +def is_pydata_sparse(x) -> bool: + """ + Return True if `x` is an array from the `sparse` package. + + This function does not import `sparse` if it has not already been imported + and is therefore cheap to use. + + + See Also + -------- + + array_namespace + is_array_api_obj + is_numpy_array + is_cupy_array + is_torch_array + is_dask_array + is_jax_array + """ + # Avoid importing jax if it isn't already + if 'sparse' not in sys.modules: + return False + + import sparse + + # TODO: Account for other backends. + return isinstance(x, sparse.SparseArray) + def is_array_api_obj(x): """ Return True if `x` is an array API compatible array object. @@ -185,6 +219,7 @@ def is_array_api_obj(x): or is_torch_array(x) \ or is_dask_array(x) \ or is_jax_array(x) \ + or is_pydata_sparse(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -253,6 +288,7 @@ def your_function(x, y): is_torch_array is_dask_array is_jax_array + is_pydata_sparse """ if use_compat not in [None, True, False]: @@ -312,6 +348,15 @@ def your_function(x, y): # not have a wrapper submodule for it. import jax.experimental.array_api as jnp namespaces.add(jnp) + elif is_pydata_sparse(x): + if use_compat is True: + _check_api_version(api_version) + raise ValueError("`sparse` does not have an array-api-compat wrapper") + else: + import sparse + # `sparse` is already an array namespace. We do not have a wrapper + # submodule for it. + namespaces.add(sparse) elif hasattr(x, '__array_namespace__'): if use_compat is True: raise ValueError("The given array does not have an array-api-compat wrapper") @@ -406,8 +451,23 @@ def device(x: Array, /) -> Device: return x.device() else: return x.device + elif is_pydata_sparse(x): + # `sparse` will gain `.device`, so check for this first. + x_device = getattr(x, 'device', None) + if x_device is not None: + return x_device + # Everything but DOK has this attr. + try: + inner = x.data + except AttributeError: + return "cpu" + # Return the device of the constituent array + return device(inner) return x.device +# Prevent shadowing, used below +_device = device + # Based on cupy.array_api.Array.to_device def _cupy_to_device(x, device, /, stream=None): import cupy as cp @@ -523,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # This import adds to_device to x import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) + elif is_pydata_sparse(x) and device == _device(x): + # Perform trivial check to return the same array if + # device is same instead of err-ing. + return x return x.to_device(device, stream=stream) def size(x): @@ -549,6 +613,7 @@ def size(x): "is_jax_array", "is_numpy_array", "is_torch_array", + "is_pydata_sparse", "size", "to_device", ] diff --git a/docs/changelog.md b/docs/changelog.md index f17eb23d..545a9aa8 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,14 @@ ## Major Changes +- Add support for `sparse`. Note that unlike other array libraries, + array-api-compat does not contain any wrappers for `sparse` functions. All + `sparse` array API support is in `sparse` itself. Thus, there is no + `array_api_compat.sparse` submodule, and + `array_namespace()` returns the `sparse` module. + +- Added the function `is_pydata_sparse(x)`. + - Drop support for Python 3.8. - NumPy 2.0 is now left completely unwrapped. diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 861b74bd..88b9edce 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -132,3 +132,6 @@ For `linalg`, several methods are missing, for example: Other methods may only be partially implemented or return incorrect results at times. The minimum supported Dask version is 2023.12.0. + +## [`sparse`](https://sparse.pydata.org/en/stable/) +Similar to JAX, `sparse` Array API support is contained directly in `sparse`. diff --git a/requirements-dev.txt b/requirements-dev.txt index 13143d6c..d06de300 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,4 @@ jax[cpu] numpy pytest torch +sparse >=0.15.1 diff --git a/tests/_helpers.py b/tests/_helpers.py index ffa2171e..e52c205d 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,19 +1,26 @@ from importlib import import_module +import sys import pytest wrapped_libraries = ["cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["numpy", "jax.numpy"] +all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"] import numpy as np if np.__version__[0] == '1': wrapped_libraries.append("numpy") +# `sparse` added array API support as of Python 3.10. +if sys.version_info >= (3, 10): + all_libraries.append('sparse') + def import_(library, wrapper=False): if library == 'cupy': pytest.importorskip(library) if wrapper: if 'jax' in library: library = 'jax.experimental.array_api' + elif library.startswith('sparse'): + library = 'sparse' else: library = 'array_api_compat.' + library diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index f5454bff..1f83a473 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat): xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) - if use_compat is True and library in ['array_api_strict', 'jax.numpy']: + if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) diff --git a/tests/test_common.py b/tests/test_common.py index 1cd396f1..798dc114 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,5 @@ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 - is_dask_array, is_jax_array) + is_dask_array, is_jax_array, is_pydata_sparse) from array_api_compat import is_array_api_obj, device, to_device @@ -16,6 +16,7 @@ 'torch': 'is_torch_array', 'dask.array': 'is_dask_array', 'jax.numpy': 'is_jax_array', + 'sparse': 'is_pydata_sparse', } @pytest.mark.parametrize('library', is_functions.keys()) @@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request): if source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU") + elif source_library == "sparse" and target_library != "sparse": + pytest.skip(reason="`sparse` does not allow implicit densification") src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[is_functions[target_library]] diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py index 8ad71a3c..a1fdf731 100644 --- a/tests/test_no_dependencies.py +++ b/tests/test_no_dependencies.py @@ -33,7 +33,7 @@ def _test_dependency(mod): # array-api-strict is an example of an array API library that isn't # wrapped by array-api-compat. - if "strict" not in mod: + if "strict" not in mod and mod != "sparse": is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array") assert not is_mod_array(a) assert mod not in sys.modules @@ -50,7 +50,7 @@ def _test_dependency(mod): # Y (except most array libraries actually do themselves depend on numpy). @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", - "jax.numpy", "array_api_strict"]) + "jax.numpy", "sparse", "array_api_strict"]) def test_numpy_dependency(library): # This import is here because it imports numpy from ._helpers import import_