Skip to content

Commit 078c823

Browse files
committed
Add device() and to_device() helper functions
1 parent 547f007 commit 078c823

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

numpy_array_api_compat/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,26 @@
1616
to ensure they are not using functionality outside of the standard, but prefer
1717
this implementation for the default when working with NumPy arrays.
1818
19+
In addition, several helper functions are provided in this library which are
20+
not part of the array API specification but which are useful for libraries
21+
writing against the array API specification who wish to support NumPy and
22+
other array API compatible libraries.
23+
1924
Known differences from the Array API spec:
2025
2126
- The array methods __array_namespace__, device, to_device, and mT are not
2227
defined. This reuses np.ndarray and we don't want to monkeypatch or wrap it.
28+
The helper functions device() and to_device() are provided to work around
29+
these missing methods. x.mT can be replaced with
30+
xp.linalg.matrix_transpose(x).
2331
2432
- NumPy value-based casting for scalars will be in effect unless explicitly
2533
disabled with the environment variable NPY_PROMOTION_STATE=weak or
2634
np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50
2735
and https://github.com/numpy/numpy/issues/22341)
2836
29-
- NumPy functions which are not wrapped may not have the same type aliases as
30-
the spec.
37+
- NumPy functions which are not wrapped may not have the same type annotations
38+
as the spec.
3139
3240
- NumPy functions which are not wrapped may not use positional-only arguments.
3341
@@ -46,3 +54,5 @@
4654
import numpy_array_api_compat.linalg
4755

4856
from .linalg import matrix_transpose, vecdot
57+
58+
from .helpers import *

numpy_array_api_compat/_helpers.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Various helper functions which are not part of the spec.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import numpy as np
8+
9+
# device and to_device are not included in array object of this library
10+
# because this library just reuses ndarray without wrapping or subclassing it.
11+
# These helper functions can be used instead of the wrapper functions for
12+
# libraries that need to support both NumPy and other libraries that use devices.
13+
def device(x: "Array", /) -> "Device":
14+
"""
15+
Hardware device the array data resides on.
16+
17+
Parameters
18+
----------
19+
x: array
20+
array instance from NumPy or an array API compatible library.
21+
22+
Returns
23+
-------
24+
out: device
25+
a ``device`` object (see the "Device Support" section of the array API specification).
26+
"""
27+
if isinstance(x, np.ndarray):
28+
return "cpu"
29+
return x.device
30+
31+
def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array":
32+
"""
33+
Copy the array from the device on which it currently resides to the specified ``device``.
34+
35+
Parameters
36+
----------
37+
x: array
38+
array instance from NumPy or an array API compatible library.
39+
device: device
40+
a ``device`` object (see the "Device Support" section of the array API specification).
41+
stream: Optional[Union[int, Any]]
42+
stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.
43+
44+
Returns
45+
-------
46+
out: array
47+
an array with the same data and data type as ``x`` and located on the specified ``device``.
48+
49+
.. note::
50+
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
51+
"""
52+
if isinstance(x, np.ndarray):
53+
if stream is not None:
54+
raise ValueError("The stream argument to to_device() is not supported")
55+
if device == 'cpu':
56+
return x
57+
raise ValueError(f"Unsupported device {device!r}")
58+
59+
return x.to_device(device, stream=stream)
60+
61+
__all__ = ['device', 'to_device']

0 commit comments

Comments
 (0)