Skip to content

Commit fd798e1

Browse files
committed
Add is_array_obj and get_namespace helpers
1 parent 53557d4 commit fd798e1

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

numpy_array_api_compat/_helpers.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,47 @@
44

55
from __future__ import annotations
66

7+
import importlib
8+
compat_namespace = importlib.import_module(__package__)
9+
710
import numpy as np
811

12+
def _is_numpy_array(x):
13+
# TODO: Should we reject ndarray subclasses?
14+
return isinstance(x, (np.ndarray, np.generic))
15+
16+
def is_array_api_obj(x):
17+
"""
18+
Check if x is an array API compatible array object.
19+
"""
20+
return _is_numpy_array(x) or hasattr(x, '__array_namespace__')
21+
22+
def get_namespace(*xs):
23+
"""
24+
Get the array API compatible namespace for the arrays `xs`.
25+
26+
`xs` should contain one or more arrays.
27+
"""
28+
namespaces = set()
29+
for x in xs:
30+
if hasattr(x, '__array_namespace__'):
31+
namespaces.add(x.__array_namespace__)
32+
elif _is_numpy_array(x):
33+
namespaces.add(compat_namespace)
34+
else:
35+
# TODO: Support Python scalars?
36+
raise ValueError("The input is not a supported array type")
37+
38+
if not namespaces:
39+
raise ValueError("Unrecognized array input")
40+
41+
if len(namespaces) != 1:
42+
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
43+
44+
xp, = namespaces
45+
46+
return xp
47+
948
# device and to_device are not included in array object of this library
1049
# because this library just reuses ndarray without wrapping or subclassing it.
1150
# These helper functions can be used instead of the wrapper functions for
@@ -24,7 +63,7 @@ def device(x: "Array", /) -> "Device":
2463
out: device
2564
a ``device`` object (see the "Device Support" section of the array API specification).
2665
"""
27-
if isinstance(x, np.ndarray):
66+
if _is_numpy_array(x):
2867
return "cpu"
2968
return x.device
3069

@@ -49,7 +88,7 @@ def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, An
4988
.. note::
5089
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
5190
"""
52-
if isinstance(x, np.ndarray):
91+
if _is_numpy_array(x):
5392
if stream is not None:
5493
raise ValueError("The stream argument to to_device() is not supported")
5594
if device == 'cpu':
@@ -58,4 +97,4 @@ def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, An
5897

5998
return x.to_device(device, stream=stream)
6099

61-
__all__ = ['device', 'to_device']
100+
__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device']

0 commit comments

Comments
 (0)