Skip to content

Add sparse compatibility layer. #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Run Tests
run: |
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse")
fi
pytest -v "${PYTEST_EXTRA[@]}"

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
libraries, or if you encounter any issues, please [open an
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
for other array libraries, or if you encounter any issues, please [open an
issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details https://data-apis.org/array-api-compat/
65 changes: 65 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def is_numpy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_pydata_sparse
"""
# Avoid importing NumPy if it isn't already
if 'numpy' not in sys.modules:
Expand Down Expand Up @@ -79,6 +80,7 @@ def is_cupy_array(x):
is_torch_array
is_dask_array
is_jax_array
is_pydata_sparse
"""
# Avoid importing NumPy if it isn't already
if 'cupy' not in sys.modules:
Expand All @@ -105,6 +107,7 @@ def is_torch_array(x):
is_cupy_array
is_dask_array
is_jax_array
is_pydata_sparse
"""
# Avoid importing torch if it isn't already
if 'torch' not in sys.modules:
Expand All @@ -131,6 +134,7 @@ def is_dask_array(x):
is_cupy_array
is_torch_array
is_jax_array
is_pydata_sparse
"""
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
Expand All @@ -157,6 +161,7 @@ def is_jax_array(x):
is_cupy_array
is_torch_array
is_dask_array
is_pydata_sparse
"""
# Avoid importing jax if it isn't already
if 'jax' not in sys.modules:
Expand All @@ -166,6 +171,35 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.

This function does not import `sparse` if it has not already been imported
and is therefore cheap to use.


See Also
--------

array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_torch_array
is_dask_array
is_jax_array
"""
# Avoid importing jax if it isn't already
if 'sparse' not in sys.modules:
return False

import sparse

# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)

def is_array_api_obj(x):
"""
Return True if `x` is an array API compatible array object.
Expand All @@ -185,6 +219,7 @@ def is_array_api_obj(x):
or is_torch_array(x) \
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand Down Expand Up @@ -253,6 +288,7 @@ def your_function(x, y):
is_torch_array
is_dask_array
is_jax_array
is_pydata_sparse

"""
if use_compat not in [None, True, False]:
Expand Down Expand Up @@ -312,6 +348,15 @@ def your_function(x, y):
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_pydata_sparse(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse
# `sparse` is already an array namespace. We do not have a wrapper
# submodule for it.
namespaces.add(sparse)
elif hasattr(x, '__array_namespace__'):
if use_compat is True:
raise ValueError("The given array does not have an array-api-compat wrapper")
Expand Down Expand Up @@ -406,8 +451,23 @@ def device(x: Array, /) -> Device:
return x.device()
else:
return x.device
elif is_pydata_sparse(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
if x_device is not None:
return x_device
# Everything but DOK has this attr.
try:
inner = x.data
except AttributeError:
return "cpu"
# Return the device of the constituent array
return device(inner)
return x.device

# Prevent shadowing, used below
_device = device

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
import cupy as cp
Expand Down Expand Up @@ -523,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# This import adds to_device to x
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_pydata_sparse(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
return x
return x.to_device(device, stream=stream)

def size(x):
Expand All @@ -549,6 +613,7 @@ def size(x):
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"is_pydata_sparse",
"size",
"to_device",
]
Expand Down
8 changes: 8 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

## Major Changes

- Add support for `sparse`. Note that unlike other array libraries,
array-api-compat does not contain any wrappers for `sparse` functions. All
`sparse` array API support is in `sparse` itself. Thus, there is no
`array_api_compat.sparse` submodule, and
`array_namespace(<pydata/sparse array>)` returns the `sparse` module.

- Added the function `is_pydata_sparse(x)`.

- Drop support for Python 3.8.

- NumPy 2.0 is now left completely unwrapped.
Expand Down
3 changes: 3 additions & 0 deletions docs/supported-array-libraries.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,6 @@ For `linalg`, several methods are missing, for example:
Other methods may only be partially implemented or return incorrect results at times.

The minimum supported Dask version is 2023.12.0.

## [`sparse`](https://sparse.pydata.org/en/stable/)
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ jax[cpu]
numpy
pytest
torch
sparse >=0.15.1
9 changes: 8 additions & 1 deletion tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from importlib import import_module
import sys

import pytest

wrapped_libraries = ["cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
import numpy as np
if np.__version__[0] == '1':
wrapped_libraries.append("numpy")

# `sparse` added array API support as of Python 3.10.
if sys.version_info >= (3, 10):
all_libraries.append('sparse')

def import_(library, wrapper=False):
if library == 'cupy':
pytest.importorskip(library)
if wrapper:
if 'jax' in library:
library = 'jax.experimental.array_api'
elif library.startswith('sparse'):
library = 'sparse'
else:
library = 'array_api_compat.' + library

Expand Down
2 changes: 1 addition & 1 deletion tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat):
xp = import_(library)

array = xp.asarray([1.0, 2.0, 3.0])
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
is_dask_array, is_jax_array)
is_dask_array, is_jax_array, is_pydata_sparse)

from array_api_compat import is_array_api_obj, device, to_device

Expand All @@ -16,6 +16,7 @@
'torch': 'is_torch_array',
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
'sparse': 'is_pydata_sparse',
}

@pytest.mark.parametrize('library', is_functions.keys())
Expand Down Expand Up @@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request):
if source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_functions[target_library]]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_no_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _test_dependency(mod):

# array-api-strict is an example of an array API library that isn't
# wrapped by array-api-compat.
if "strict" not in mod:
if "strict" not in mod and mod != "sparse":
is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
assert not is_mod_array(a)
assert mod not in sys.modules
Expand All @@ -50,7 +50,7 @@ def _test_dependency(mod):
# Y (except most array libraries actually do themselves depend on numpy).

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
"jax.numpy", "array_api_strict"])
"jax.numpy", "sparse", "array_api_strict"])
def test_numpy_dependency(library):
# This import is here because it imports numpy
from ._helpers import import_
Expand Down
Loading