Skip to content

Commit 8b9e0c0

Browse files
authored
Merge pull request #134 from hameerabbasi/sparse-compat
Add `sparse` compatibility layer.
2 parents 376038e + b77472d commit 8b9e0c0

File tree

10 files changed

+95
-8
lines changed

10 files changed

+95
-8
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
- name: Run Tests
3434
run: |
3535
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
36-
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
36+
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse")
3737
fi
3838
pytest -v "${PYTEST_EXTRA[@]}"
3939

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
This is a small wrapper around common array libraries that is compatible with
44
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5-
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
6-
libraries, or if you encounter any issues, please [open an
5+
NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
6+
for other array libraries, or if you encounter any issues, please [open an
77
issue](https://github.com/data-apis/array-api-compat/issues).
88

99
See the documentation for more details https://data-apis.org/array-api-compat/

array_api_compat/common/_helpers.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def is_numpy_array(x):
5050
is_torch_array
5151
is_dask_array
5252
is_jax_array
53+
is_pydata_sparse
5354
"""
5455
# Avoid importing NumPy if it isn't already
5556
if 'numpy' not in sys.modules:
@@ -79,6 +80,7 @@ def is_cupy_array(x):
7980
is_torch_array
8081
is_dask_array
8182
is_jax_array
83+
is_pydata_sparse
8284
"""
8385
# Avoid importing NumPy if it isn't already
8486
if 'cupy' not in sys.modules:
@@ -105,6 +107,7 @@ def is_torch_array(x):
105107
is_cupy_array
106108
is_dask_array
107109
is_jax_array
110+
is_pydata_sparse
108111
"""
109112
# Avoid importing torch if it isn't already
110113
if 'torch' not in sys.modules:
@@ -131,6 +134,7 @@ def is_dask_array(x):
131134
is_cupy_array
132135
is_torch_array
133136
is_jax_array
137+
is_pydata_sparse
134138
"""
135139
# Avoid importing dask if it isn't already
136140
if 'dask.array' not in sys.modules:
@@ -157,6 +161,7 @@ def is_jax_array(x):
157161
is_cupy_array
158162
is_torch_array
159163
is_dask_array
164+
is_pydata_sparse
160165
"""
161166
# Avoid importing jax if it isn't already
162167
if 'jax' not in sys.modules:
@@ -166,6 +171,35 @@ def is_jax_array(x):
166171

167172
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
168173

174+
175+
def is_pydata_sparse(x) -> bool:
176+
"""
177+
Return True if `x` is an array from the `sparse` package.
178+
179+
This function does not import `sparse` if it has not already been imported
180+
and is therefore cheap to use.
181+
182+
183+
See Also
184+
--------
185+
186+
array_namespace
187+
is_array_api_obj
188+
is_numpy_array
189+
is_cupy_array
190+
is_torch_array
191+
is_dask_array
192+
is_jax_array
193+
"""
194+
# Avoid importing jax if it isn't already
195+
if 'sparse' not in sys.modules:
196+
return False
197+
198+
import sparse
199+
200+
# TODO: Account for other backends.
201+
return isinstance(x, sparse.SparseArray)
202+
169203
def is_array_api_obj(x):
170204
"""
171205
Return True if `x` is an array API compatible array object.
@@ -185,6 +219,7 @@ def is_array_api_obj(x):
185219
or is_torch_array(x) \
186220
or is_dask_array(x) \
187221
or is_jax_array(x) \
222+
or is_pydata_sparse(x) \
188223
or hasattr(x, '__array_namespace__')
189224

190225
def _check_api_version(api_version):
@@ -253,6 +288,7 @@ def your_function(x, y):
253288
is_torch_array
254289
is_dask_array
255290
is_jax_array
291+
is_pydata_sparse
256292
257293
"""
258294
if use_compat not in [None, True, False]:
@@ -312,6 +348,15 @@ def your_function(x, y):
312348
# not have a wrapper submodule for it.
313349
import jax.experimental.array_api as jnp
314350
namespaces.add(jnp)
351+
elif is_pydata_sparse(x):
352+
if use_compat is True:
353+
_check_api_version(api_version)
354+
raise ValueError("`sparse` does not have an array-api-compat wrapper")
355+
else:
356+
import sparse
357+
# `sparse` is already an array namespace. We do not have a wrapper
358+
# submodule for it.
359+
namespaces.add(sparse)
315360
elif hasattr(x, '__array_namespace__'):
316361
if use_compat is True:
317362
raise ValueError("The given array does not have an array-api-compat wrapper")
@@ -406,8 +451,23 @@ def device(x: Array, /) -> Device:
406451
return x.device()
407452
else:
408453
return x.device
454+
elif is_pydata_sparse(x):
455+
# `sparse` will gain `.device`, so check for this first.
456+
x_device = getattr(x, 'device', None)
457+
if x_device is not None:
458+
return x_device
459+
# Everything but DOK has this attr.
460+
try:
461+
inner = x.data
462+
except AttributeError:
463+
return "cpu"
464+
# Return the device of the constituent array
465+
return device(inner)
409466
return x.device
410467

468+
# Prevent shadowing, used below
469+
_device = device
470+
411471
# Based on cupy.array_api.Array.to_device
412472
def _cupy_to_device(x, device, /, stream=None):
413473
import cupy as cp
@@ -523,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
523583
# This import adds to_device to x
524584
import jax.experimental.array_api # noqa: F401
525585
return x.to_device(device, stream=stream)
586+
elif is_pydata_sparse(x) and device == _device(x):
587+
# Perform trivial check to return the same array if
588+
# device is same instead of err-ing.
589+
return x
526590
return x.to_device(device, stream=stream)
527591

528592
def size(x):
@@ -549,6 +613,7 @@ def size(x):
549613
"is_jax_array",
550614
"is_numpy_array",
551615
"is_torch_array",
616+
"is_pydata_sparse",
552617
"size",
553618
"to_device",
554619
]

docs/changelog.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44

55
## Major Changes
66

7+
- Add support for `sparse`. Note that unlike other array libraries,
8+
array-api-compat does not contain any wrappers for `sparse` functions. All
9+
`sparse` array API support is in `sparse` itself. Thus, there is no
10+
`array_api_compat.sparse` submodule, and
11+
`array_namespace(<pydata/sparse array>)` returns the `sparse` module.
12+
13+
- Added the function `is_pydata_sparse(x)`.
14+
715
- Drop support for Python 3.8.
816

917
- NumPy 2.0 is now left completely unwrapped.

docs/supported-array-libraries.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,6 @@ For `linalg`, several methods are missing, for example:
132132
Other methods may only be partially implemented or return incorrect results at times.
133133

134134
The minimum supported Dask version is 2023.12.0.
135+
136+
## [`sparse`](https://sparse.pydata.org/en/stable/)
137+
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ jax[cpu]
44
numpy
55
pytest
66
torch
7+
sparse >=0.15.1

tests/_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
from importlib import import_module
2+
import sys
23

34
import pytest
45

56
wrapped_libraries = ["cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
7+
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
78
import numpy as np
89
if np.__version__[0] == '1':
910
wrapped_libraries.append("numpy")
1011

12+
# `sparse` added array API support as of Python 3.10.
13+
if sys.version_info >= (3, 10):
14+
all_libraries.append('sparse')
15+
1116
def import_(library, wrapper=False):
1217
if library == 'cupy':
1318
pytest.importorskip(library)
1419
if wrapper:
1520
if 'jax' in library:
1621
library = 'jax.experimental.array_api'
22+
elif library.startswith('sparse'):
23+
library = 'sparse'
1724
else:
1825
library = 'array_api_compat.' + library
1926

tests/test_array_namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22+
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
2525
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)

tests/test_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array)
2+
is_dask_array, is_jax_array, is_pydata_sparse)
33

44
from array_api_compat import is_array_api_obj, device, to_device
55

@@ -16,6 +16,7 @@
1616
'torch': 'is_torch_array',
1717
'dask.array': 'is_dask_array',
1818
'jax.numpy': 'is_jax_array',
19+
'sparse': 'is_pydata_sparse',
1920
}
2021

2122
@pytest.mark.parametrize('library', is_functions.keys())
@@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request):
7677
if source_library == "cupy" and target_library != "cupy":
7778
# cupy explicitly disallows implicit conversions to CPU
7879
pytest.skip(reason="cupy does not support implicit conversion to CPU")
80+
elif source_library == "sparse" and target_library != "sparse":
81+
pytest.skip(reason="`sparse` does not allow implicit densification")
7982
src_lib = import_(source_library, wrapper=True)
8083
tgt_lib = import_(target_library, wrapper=True)
8184
is_tgt_type = globals()[is_functions[target_library]]

tests/test_no_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _test_dependency(mod):
3333

3434
# array-api-strict is an example of an array API library that isn't
3535
# wrapped by array-api-compat.
36-
if "strict" not in mod:
36+
if "strict" not in mod and mod != "sparse":
3737
is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
3838
assert not is_mod_array(a)
3939
assert mod not in sys.modules
@@ -50,7 +50,7 @@ def _test_dependency(mod):
5050
# Y (except most array libraries actually do themselves depend on numpy).
5151

5252
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
53-
"jax.numpy", "array_api_strict"])
53+
"jax.numpy", "sparse", "array_api_strict"])
5454
def test_numpy_dependency(library):
5555
# This import is here because it imports numpy
5656
from ._helpers import import_

0 commit comments

Comments
 (0)