Skip to content

Commit 2ee6902

Browse files
committed
Make is_numpy_array, is_cupy_array, is_torch_array, and is_dask_array public
1 parent 9cb5a13 commit 2ee6902

File tree

4 files changed

+52
-20
lines changed

4 files changed

+52
-20
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ part of the specification but which are useful for using the array API:
9999
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
100100
object.
101101

102+
- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`,
103+
`is_dask_array(x)`: return `True` if `x` is an array from the corresponding
104+
library. These functions do not import the underlying library if it has not
105+
already been imported, so they are cheap to use.
106+
102107
- `array_namespace(*xs)`: Get the corresponding array API namespace for the
103108
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
104109
namespace will be `array_api_compat.numpy`. Note that this function will

array_api_compat/common/_aliases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from types import ModuleType
1414
import inspect
1515

16-
from ._helpers import _check_device, _is_numpy_array, array_namespace
16+
from ._helpers import _check_device, is_numpy_array, array_namespace
1717

1818
# These functions are modified from the NumPy versions.
1919

@@ -309,7 +309,7 @@ def _asarray(
309309
raise ValueError("Unrecognized namespace argument to asarray()")
310310

311311
_check_device(xp, device)
312-
if _is_numpy_array(obj):
312+
if is_numpy_array(obj):
313313
import numpy as np
314314
if hasattr(np, '_CopyMode'):
315315
# Not present in older NumPys

array_api_compat/common/_helpers.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sys
1111
import math
1212

13-
def _is_numpy_array(x):
13+
def is_numpy_array(x):
1414
# Avoid importing NumPy if it isn't already
1515
if 'numpy' not in sys.modules:
1616
return False
@@ -20,7 +20,7 @@ def _is_numpy_array(x):
2020
# TODO: Should we reject ndarray subclasses?
2121
return isinstance(x, (np.ndarray, np.generic))
2222

23-
def _is_cupy_array(x):
23+
def is_cupy_array(x):
2424
# Avoid importing NumPy if it isn't already
2525
if 'cupy' not in sys.modules:
2626
return False
@@ -30,7 +30,7 @@ def _is_cupy_array(x):
3030
# TODO: Should we reject ndarray subclasses?
3131
return isinstance(x, (cp.ndarray, cp.generic))
3232

33-
def _is_torch_array(x):
33+
def is_torch_array(x):
3434
# Avoid importing torch if it isn't already
3535
if 'torch' not in sys.modules:
3636
return False
@@ -40,7 +40,7 @@ def _is_torch_array(x):
4040
# TODO: Should we reject ndarray subclasses?
4141
return isinstance(x, torch.Tensor)
4242

43-
def _is_dask_array(x):
43+
def is_dask_array(x):
4444
# Avoid importing dask if it isn't already
4545
if 'dask.array' not in sys.modules:
4646
return False
@@ -53,10 +53,10 @@ def is_array_api_obj(x):
5353
"""
5454
Check if x is an array API compatible array object.
5555
"""
56-
return _is_numpy_array(x) \
57-
or _is_cupy_array(x) \
58-
or _is_torch_array(x) \
59-
or _is_dask_array(x) \
56+
return is_numpy_array(x) \
57+
or is_cupy_array(x) \
58+
or is_torch_array(x) \
59+
or is_dask_array(x) \
6060
or hasattr(x, '__array_namespace__')
6161

6262
def _check_api_version(api_version):
@@ -81,31 +81,31 @@ def your_function(x, y):
8181
"""
8282
namespaces = set()
8383
for x in xs:
84-
if _is_numpy_array(x):
84+
if is_numpy_array(x):
8585
_check_api_version(api_version)
8686
if _use_compat:
8787
from .. import numpy as numpy_namespace
8888
namespaces.add(numpy_namespace)
8989
else:
9090
import numpy as np
9191
namespaces.add(np)
92-
elif _is_cupy_array(x):
92+
elif is_cupy_array(x):
9393
_check_api_version(api_version)
9494
if _use_compat:
9595
from .. import cupy as cupy_namespace
9696
namespaces.add(cupy_namespace)
9797
else:
9898
import cupy as cp
9999
namespaces.add(cp)
100-
elif _is_torch_array(x):
100+
elif is_torch_array(x):
101101
_check_api_version(api_version)
102102
if _use_compat:
103103
from .. import torch as torch_namespace
104104
namespaces.add(torch_namespace)
105105
else:
106106
import torch
107107
namespaces.add(torch)
108-
elif _is_dask_array(x):
108+
elif is_dask_array(x):
109109
_check_api_version(api_version)
110110
if _use_compat:
111111
from ..dask import array as dask_namespace
@@ -156,7 +156,7 @@ def device(x: "Array", /) -> "Device":
156156
out: device
157157
a ``device`` object (see the "Device Support" section of the array API specification).
158158
"""
159-
if _is_numpy_array(x):
159+
if is_numpy_array(x):
160160
return "cpu"
161161
return x.device
162162

@@ -225,18 +225,18 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
225225
.. note::
226226
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
227227
"""
228-
if _is_numpy_array(x):
228+
if is_numpy_array(x):
229229
if stream is not None:
230230
raise ValueError("The stream argument to to_device() is not supported")
231231
if device == 'cpu':
232232
return x
233233
raise ValueError(f"Unsupported device {device!r}")
234-
elif _is_cupy_array(x):
234+
elif is_cupy_array(x):
235235
# cupy does not yet have to_device
236236
return _cupy_to_device(x, device, stream=stream)
237-
elif _is_torch_array(x):
237+
elif is_torch_array(x):
238238
return _torch_to_device(x, device, stream=stream)
239-
elif _is_dask_array(x):
239+
elif is_dask_array(x):
240240
if stream is not None:
241241
raise ValueError("The stream argument to to_device() is not supported")
242242
# TODO: What if our array is on the GPU already?
@@ -253,4 +253,6 @@ def size(x):
253253
return None
254254
return math.prod(x.shape)
255255

256-
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']
256+
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device',
257+
'to_device', 'size', 'is_numpy_array', 'is_cupy_array',
258+
'is_torch_array', 'is_dask_array']

tests/test_helpers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array,
2+
is_dask_array, is_array_api_obj)
3+
4+
from ._helpers import import_
5+
6+
import pytest
7+
8+
is_functions = {
9+
'numpy': 'is_numpy_array',
10+
'cupy': 'is_cupy_array',
11+
'torch': 'is_torch_array',
12+
'dask.array': 'is_dask_array',
13+
}
14+
15+
@pytest.mark.parametrize('library', is_functions.keys())
16+
@pytest.mark.parametrize('func', is_functions.values())
17+
def test_is_xp_array(library, func):
18+
lib = import_(library)
19+
is_func = globals()[func]
20+
21+
x = lib.asarray([1, 2, 3])
22+
23+
assert is_func(x) == (func == is_functions[library])
24+
25+
assert is_array_api_obj(x)

0 commit comments

Comments
 (0)