From 88abd2be9f3c7dca1efb87a3061424b4c257fe65 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 3 May 2024 08:38:41 +0200 Subject: [PATCH 01/12] WIP: Add `sparse` compatibility layer. --- array_api_compat/common/_helpers.py | 63 +++++++++++++++++++++++++++++ array_api_compat/sparse/__init__.py | 5 +++ requirements-dev.txt | 1 + 3 files changed, 69 insertions(+) create mode 100644 array_api_compat/sparse/__init__.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 982b284a..9ac7784f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -50,6 +50,7 @@ def is_numpy_array(x): is_torch_array is_dask_array is_jax_array + is_sparse_array """ # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: @@ -79,6 +80,7 @@ def is_cupy_array(x): is_torch_array is_dask_array is_jax_array + is_sparse_array """ # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: @@ -105,6 +107,7 @@ def is_torch_array(x): is_cupy_array is_dask_array is_jax_array + is_sparse_array """ # Avoid importing torch if it isn't already if 'torch' not in sys.modules: @@ -131,6 +134,7 @@ def is_dask_array(x): is_cupy_array is_torch_array is_jax_array + is_sparse_array """ # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: @@ -157,6 +161,7 @@ def is_jax_array(x): is_cupy_array is_torch_array is_dask_array + is_sparse_array """ # Avoid importing jax if it isn't already if 'jax' not in sys.modules: @@ -166,6 +171,35 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + +def is_sparse_array(x) -> bool: + """ + Return True if `x` is a `sparse` array. + + This function does not import `sparse` if it has not already been imported + and is therefore cheap to use. 
+ + + See Also + -------- + + array_namespace + is_array_api_obj + is_numpy_array + is_cupy_array + is_torch_array + is_dask_array + is_jax_array + """ + # Avoid importing sparse if it isn't already + if 'sparse' not in sys.modules: + return False + + import sparse + + # TODO: Account for other backends. + return isinstance(x, sparse.SparseArray) + def is_array_api_obj(x): """ Return True if `x` is an array API compatible array object. @@ -185,6 +219,7 @@ def is_array_api_obj(x): or is_torch_array(x) \ or is_dask_array(x) \ or is_jax_array(x) \ + or is_sparse_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -253,6 +288,7 @@ def your_function(x, y): is_torch_array is_dask_array is_jax_array + is_sparse_array """ if use_compat not in [None, True, False]: @@ -312,6 +348,13 @@ def your_function(x, y): # not have a wrapper submodule for it. import jax.experimental.array_api as jnp namespaces.add(jnp) + elif is_sparse_array(x): + if use_compat is True: + _check_api_version(api_version) + raise ValueError("`sparse` does not have an array-api-compat wrapper") + else: + import sparse + namespaces.add(sparse) elif hasattr(x, '__array_namespace__'): if use_compat is True: raise ValueError("The given array does not have an array-api-compat wrapper") @@ -406,8 +449,23 @@ def device(x: Array, /) -> Device: return x.device() else: return x.device + elif is_sparse_array(x): + # `sparse` will gain `.device`, so check for this first. + x_device = getattr(x, 'device', None) + if x_device is not None: + return x_device + # Everything but DOK has this attr. 
+ try: + inner = x.data + except AttributeError: + return "cpu" + # Return the device of the constituent array + return device(inner) return x.device +# Prevent shadowing, used below +_device = device + # Based on cupy.array_api.Array.to_device def _cupy_to_device(x, device, /, stream=None): import cupy as cp @@ -523,6 +581,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # This import adds to_device to x import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) + elif is_sparse_array(x) and device == _device(x): + # Perform trivial check to return the same array if + # device is same instead of err-ing. + return x return x.to_device(device, stream=stream) def size(x): @@ -549,6 +611,7 @@ def size(x): "is_jax_array", "is_numpy_array", "is_torch_array", + "is_sparse_array", "size", "to_device", ] diff --git a/array_api_compat/sparse/__init__.py b/array_api_compat/sparse/__init__.py new file mode 100644 index 00000000..7f3483d2 --- /dev/null +++ b/array_api_compat/sparse/__init__.py @@ -0,0 +1,5 @@ +from sparse import * # noqa: F403 +from ..common._aliases import * # noqa: F403 +from ..common._helpers import * # noqa: F401,F403 + +__array_api_version__ = '2022.12' diff --git a/requirements-dev.txt b/requirements-dev.txt index 13143d6c..d06de300 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,4 @@ jax[cpu] numpy pytest torch +sparse >=0.15.1 From a4966e7f64309581e3003985f72bef6ce2787a8b Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 13 May 2024 07:43:56 +0200 Subject: [PATCH 02/12] Remove incorrect aliases. 
--- array_api_compat/sparse/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_compat/sparse/__init__.py b/array_api_compat/sparse/__init__.py index 7f3483d2..5102fef4 100644 --- a/array_api_compat/sparse/__init__.py +++ b/array_api_compat/sparse/__init__.py @@ -1,5 +1,4 @@ from sparse import * # noqa: F403 -from ..common._aliases import * # noqa: F403 from ..common._helpers import * # noqa: F401,F403 __array_api_version__ = '2022.12' From fb14cc815b9cb0e4b7d2ae8180b685537443c827 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 13 May 2024 07:50:25 +0200 Subject: [PATCH 03/12] Reword docstring for `is_sparse_array`. --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 9ac7784f..b673b7f7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -174,7 +174,7 @@ def is_jax_array(x): def is_sparse_array(x) -> bool: """ - Return True if `x` is a `sparse` array. + Return True if `x` is an array from the `sparse` package. This function does not import `sparse` if it has not already been imported and is therefore cheap to use. From b92a35c474db416bfd27278ae7614745483e1161 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 13 May 2024 07:59:33 +0200 Subject: [PATCH 04/12] Remove wrapper submodule. 
--- array_api_compat/common/_helpers.py | 2 ++ array_api_compat/sparse/__init__.py | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) delete mode 100644 array_api_compat/sparse/__init__.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b673b7f7..22fc1274 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -354,6 +354,8 @@ def your_function(x, y): raise ValueError("`sparse` does not have an array-api-compat wrapper") else: import sparse + # `sparse` is already an array namespace. We do not have a wrapper + # submodule for it. namespaces.add(sparse) elif hasattr(x, '__array_namespace__'): if use_compat is True: diff --git a/array_api_compat/sparse/__init__.py b/array_api_compat/sparse/__init__.py deleted file mode 100644 index 5102fef4..00000000 --- a/array_api_compat/sparse/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from sparse import * # noqa: F403 -from ..common._helpers import * # noqa: F401,F403 - -__array_api_version__ = '2022.12' From 7ebc3c096ad022f1b3448ca233377a3ec72b3af8 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 06:56:05 +0200 Subject: [PATCH 05/12] Rename `is_sparse_array` -> `is_pydata_sparse`. 
--- array_api_compat/common/_helpers.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 22fc1274..79354487 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -50,7 +50,7 @@ def is_numpy_array(x): is_torch_array is_dask_array is_jax_array - is_sparse_array + is_pydata_sparse """ # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: @@ -80,7 +80,7 @@ def is_cupy_array(x): is_torch_array is_dask_array is_jax_array - is_sparse_array + is_pydata_sparse """ # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: @@ -107,7 +107,7 @@ def is_torch_array(x): is_cupy_array is_dask_array is_jax_array - is_sparse_array + is_pydata_sparse """ # Avoid importing torch if it isn't already if 'torch' not in sys.modules: @@ -134,7 +134,7 @@ def is_dask_array(x): is_cupy_array is_torch_array is_jax_array - is_sparse_array + is_pydata_sparse """ # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: @@ -161,7 +161,7 @@ def is_jax_array(x): is_cupy_array is_torch_array is_dask_array - is_sparse_array + is_pydata_sparse """ # Avoid importing jax if it isn't already if 'jax' not in sys.modules: @@ -172,7 +172,7 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_sparse_array(x) -> bool: +def is_pydata_sparse(x) -> bool: """ Return True if `x` is an array from the `sparse` package. 
@@ -219,7 +219,7 @@ def is_array_api_obj(x): or is_torch_array(x) \ or is_dask_array(x) \ or is_jax_array(x) \ - or is_sparse_array(x) \ + or is_pydata_sparse(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -288,7 +288,7 @@ def your_function(x, y): is_torch_array is_dask_array is_jax_array - is_sparse_array + is_pydata_sparse """ if use_compat not in [None, True, False]: @@ -348,7 +348,7 @@ def your_function(x, y): # not have a wrapper submodule for it. import jax.experimental.array_api as jnp namespaces.add(jnp) - elif is_sparse_array(x): + elif is_pydata_sparse(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") @@ -451,7 +451,7 @@ def device(x: Array, /) -> Device: return x.device() else: return x.device - elif is_sparse_array(x): + elif is_pydata_sparse(x): # `sparse` will gain `.device`, so check for this first. x_device = getattr(x, 'device', None) if x_device is not None: @@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # This import adds to_device to x import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) - elif is_sparse_array(x) and device == _device(x): + elif is_pydata_sparse(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. return x @@ -613,7 +613,7 @@ def size(x): "is_jax_array", "is_numpy_array", "is_torch_array", - "is_sparse_array", + "is_pydata_sparse", "size", "to_device", ] From d3c6636bdcf878b98711373c9fab99b74b924f6d Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 07:02:40 +0200 Subject: [PATCH 06/12] Add `sparse` to the documentation. 
--- README.md | 4 ++-- docs/changelog.md | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f567606..7e96c4df 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array -libraries, or if you encounter any issues, please [open an +NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support +for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). See the documentation for more details https://data-apis.org/array-api-compat/ diff --git a/docs/changelog.md b/docs/changelog.md index f17eb23d..545a9aa8 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,14 @@ ## Major Changes +- Add support for `sparse`. Note that unlike other array libraries, + array-api-compat does not contain any wrappers for `sparse` functions. All + `sparse` array API support is in `sparse` itself. Thus, there is no + `array_api_compat.sparse` submodule, and + `array_namespace()` returns the `sparse` module. + +- Added the function `is_pydata_sparse(x)`. + - Drop support for Python 3.8. - NumPy 2.0 is now left completely unwrapped. From 855756d17727bb2998c846a7c9a230131a525e46 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 07:26:38 +0200 Subject: [PATCH 07/12] Modify testsuite to add `sparse` support. 
--- tests/_helpers.py | 4 +++- tests/test_array_namespace.py | 2 +- tests/test_common.py | 5 ++++- tests/test_no_dependencies.py | 4 ++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index ffa2171e..bf9b1504 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -3,7 +3,7 @@ import pytest wrapped_libraries = ["cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["numpy", "jax.numpy"] +all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"] import numpy as np if np.__version__[0] == '1': wrapped_libraries.append("numpy") @@ -14,6 +14,8 @@ def import_(library, wrapper=False): if wrapper: if 'jax' in library: library = 'jax.experimental.array_api' + elif library.startswith('sparse'): + library = 'sparse' else: library = 'array_api_compat.' + library diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index f5454bff..1f83a473 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat): xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) - if use_compat is True and library in ['array_api_strict', 'jax.numpy']: + if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) diff --git a/tests/test_common.py b/tests/test_common.py index 1cd396f1..798dc114 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,5 @@ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 - is_dask_array, is_jax_array) + is_dask_array, is_jax_array, is_pydata_sparse) from array_api_compat import is_array_api_obj, device, to_device @@ -16,6 +16,7 @@ 'torch': 'is_torch_array', 'dask.array': 'is_dask_array', 'jax.numpy': 
'is_jax_array', + 'sparse': 'is_pydata_sparse', } @pytest.mark.parametrize('library', is_functions.keys()) @@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request): if source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU") + elif source_library == "sparse" and target_library != "sparse": + pytest.skip(reason="`sparse` does not allow implicit densification") src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[is_functions[target_library]] diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py index 8ad71a3c..a1fdf731 100644 --- a/tests/test_no_dependencies.py +++ b/tests/test_no_dependencies.py @@ -33,7 +33,7 @@ def _test_dependency(mod): # array-api-strict is an example of an array API library that isn't # wrapped by array-api-compat. - if "strict" not in mod: + if "strict" not in mod and mod != "sparse": is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array") assert not is_mod_array(a) assert mod not in sys.modules @@ -50,7 +50,7 @@ def _test_dependency(mod): # Y (except most array libraries actually do themselves depend on numpy). @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", - "jax.numpy", "array_api_strict"]) + "jax.numpy", "sparse", "array_api_strict"]) def test_numpy_dependency(library): # This import is here because it imports numpy from ._helpers import import_ From 9c408674eb9cefc387d3fab96fa6cfdeedf79767 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 07:46:04 +0200 Subject: [PATCH 08/12] Only test `sparse` for Py3.10+. 
--- tests/_helpers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/_helpers.py b/tests/_helpers.py index bf9b1504..e52c205d 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,4 +1,5 @@ from importlib import import_module +import sys import pytest @@ -8,6 +9,10 @@ if np.__version__[0] == '1': wrapped_libraries.append("numpy") +# `sparse` added array API support as of Python 3.10. +if sys.version_info >= (3, 10): + all_libraries.append('sparse') + def import_(library, wrapper=False): if library == 'cupy': pytest.importorskip(library) From 44681b6b95c36f2428f2c0fc0ce4f1bd707ff1f8 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 08:31:54 +0200 Subject: [PATCH 09/12] Only test `sparse` when it can actually be imported. --- tests/_helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index e52c205d..cd9044de 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -14,7 +14,9 @@ all_libraries.append('sparse') def import_(library, wrapper=False): - if library == 'cupy': + # CuPy requires a GPU + # `sparse` has a dependency conflict with NumPy 1.21 + if library in {'cupy', 'sparse'}: pytest.importorskip(library) if wrapper: if 'jax' in library: From efbd400f4f3a60a116ea5625d0ec1c2b9aa7b4d1 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 10:20:22 +0200 Subject: [PATCH 10/12] Revert "Only test `sparse` when it can actually be imported." This reverts commit 44681b6b95c36f2428f2c0fc0ce4f1bd707ff1f8. 
--- tests/_helpers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index cd9044de..e52c205d 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -14,9 +14,7 @@ all_libraries.append('sparse') def import_(library, wrapper=False): - # CuPy requires a GPU - # `sparse` has a dependency conflict with NumPy 1.21 - if library in {'cupy', 'sparse'}: + if library == 'cupy': pytest.importorskip(library) if wrapper: if 'jax' in library: From b96d0a6af94eded445bb6aa9e63aeaa0133cc67e Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 15 May 2024 10:21:35 +0200 Subject: [PATCH 11/12] Don't test `sparse` for NumPy 1.21. --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8b3f4e64..e04f7447 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: - name: Run Tests run: | if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask") + PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") fi pytest -v "${PYTEST_EXTRA[@]}" From b77472d1c9ec9c32b6932c7f2153da2f096071e0 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 17 May 2024 10:16:51 +0200 Subject: [PATCH 12/12] Add `sparse` to the list of supported libraries. 
--- docs/supported-array-libraries.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 861b74bd..88b9edce 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -132,3 +132,6 @@ For `linalg`, several methods are missing, for example: Other methods may only be partially implemented or return incorrect results at times. The minimum supported Dask version is 2023.12.0. + +## [`sparse`](https://sparse.pydata.org/en/stable/) +Similar to JAX, `sparse` Array API support is contained directly in `sparse`.