Skip to content

CuPy support #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9f2afe8
Add a get_xp decorator to support multiple array namespaces
asmeurer Nov 15, 2022
e9c52c4
Add get_xp decorator to genericize the namespace for the aliases
asmeurer Nov 17, 2022
1751356
Require the namespace argument to asarray() when it's ambiguous
asmeurer Nov 17, 2022
b3a12d9
Move all the NumPy functionality into a numpy submodule
asmeurer Nov 18, 2022
cf4083a
Move the wrapper code into common/, and make linalg use @get_xp
asmeurer Nov 29, 2022
7333696
Move _typing.py from numpy/ to common/
asmeurer Nov 29, 2022
076848e
Add a cupy submodule
asmeurer Nov 29, 2022
ed5705f
Rename numpy_array_api_compat/ to array_api_compat/
asmeurer Nov 29, 2022
420c0da
Add __array_api_version__
asmeurer Nov 29, 2022
005852f
Remove library-specific stuff from common/_typing.py
asmeurer Nov 29, 2022
2912c9e
Refactor how get_xp works
asmeurer Nov 29, 2022
5775b11
Rename common.linalg to common._linalg
asmeurer Nov 29, 2022
936a8ad
Fix arange()
asmeurer Nov 29, 2022
d88f709
Fix cupy asarray to create cupy arrays instead of numpy arrays
asmeurer Nov 29, 2022
360ea18
Re-enable the signature fix in get_xp
asmeurer Nov 29, 2022
fece5e0
Fix some issues with the linalg wrapping
asmeurer Nov 29, 2022
c91360b
Export helpers to the top-level namespace
asmeurer Nov 30, 2022
d19c1a2
Fix full_like and linspace
asmeurer Nov 30, 2022
6c54b6b
Fix permute_dims
asmeurer Nov 30, 2022
e996d22
Add more information to the README
asmeurer Dec 1, 2022
82365eb
Add a test that vendoring works
asmeurer Dec 1, 2022
a83b15c
Fixes to the README
asmeurer Dec 1, 2022
8d2d37a
Fix missing sentence in the README
asmeurer Dec 5, 2022
732b493
Move vendor_test to the top-level
asmeurer Dec 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 156 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
# NumPy Array API compatibility library
# Array API compatibility library

This is a small wrapper around NumPy that is compatible with the [Array API
standard](https://data-apis.org/array-api/latest/). See also [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html).
This is a small wrapper around NumPy and CuPy that is compatible with the
[Array API standard](https://data-apis.org/array-api/latest/). See also [NEP
47](https://numpy.org/neps/nep-0047-array-api-standard.html).

Unlike `numpy.array_api`, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
Array API, but rather just an extension of the main NumPy and CuPy namespaces
with changes needed to be compliant with the Array API.

Library authors using the Array API may wish to test against `numpy.array_api`
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy or CuPy arrays.

See https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike `numpy.array_api`, this package does not use a
separate Array object, but rather just uses `numpy.ndarray` directly.

Note that some of the functionality in this library is backwards incompatible
with NumPy.

This library also supports CuPy in addition to NumPy. If you want support for
other array libraries, please [open an
issue](https://github.com/data-apis/array-api-compat/issues).

Library authors using the Array API may wish to test against `numpy.array_api`
to ensure they are not using functionality outside of the standard, but prefer
this implementation for end users who use NumPy arrays.
Expand All @@ -28,5 +38,144 @@ import numpy as np
with

```py
import numpy_array_api_compat as np
import array_api_compat.numpy as np
```

and replace

```py
import cupy as cp
```

with

```py
import array_api_compat.cupy as cp
```

Each will include all the functions from the normal NumPy/CuPy namespace,
except that functions that are part of the array API are wrapped so that they
have the correct array API behavior. In each case, the array object
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

incomplete sentence here - and it seems to start saying something that's fairly important.



## Helper Functions

In addition to the default NumPy/CuPy namespace and functions in the array API
specification, there are several helper functions
included that aren't part of the specification but which are useful for using
the array API:

- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
object.

- `get_namespace(*xs)`: Get the corresponding array API namespace for the
arrays `xs`. If the arrays are NumPy or CuPy arrays, the returned namespace
will be `array_api_compat.numpy` or `array_api_compat.cupy` so that it is
array API compatible.

- `device(x)`: Equivalent to
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
in the array API specification. Included because `numpy.ndarray` does not
include the `device` attribute and this library does not wrap or extend the
array object. Note that for NumPy, `device` is always `"cpu"`.

- `to_device(x, device, /, *, stream=None)`: Equivalent to
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
Included because neither NumPy's nor CuPy's ndarray objects include this
method. For NumPy, this function effectively does nothing since the only
supported device is the CPU, but for CuPy, this method supports CuPy CUDA
[Device](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Device.html)
and
[Stream](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html)
objects.

## Known Differences from the Array API Specification

There are some known differences between this library and the array API
specification:

- The array methods `__array_namespace__`, `device` (for NumPy), `to_device`,
and `mT` are not defined. This reuses `np.ndarray` and `cp.ndarray` and we
don't want to monkeypatch or wrap it. The helper functions `device()` and
`to_device()` are provided to work around these missing methods (see above).
`x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`.
`get_namespace(x)` should be used instead of `x.__array_namespace__`.

- NumPy value-based casting for scalars will be in effect unless explicitly
disabled with the environment variable NPY_PROMOTION_STATE=weak or
np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50
and https://github.com/numpy/numpy/issues/22341)

- Functions which are not wrapped may not have the same type annotations
as the spec.

- Functions which are not wrapped may not use positional-only arguments.

## Vendoring

This library supports vendoring as an installation method. To vendor the
library, simply copy `array_api_compat` into the appropriate place in the
library, like

```
cp -R array_api_compat/ mylib/vendored/array_api_compat
```

You may also rename it to something else if you like (nowhere in the code
references the name "array_api_compat").

Alternatively, the library may be installed as dependency on PyPI.

## Implementation

As noted before, the goal of this library is to reuse the NumPy and CuPy array
objects, rather than wrapping or extending them. This means that the functions
need to accept and return `np.ndarray` for NumPy and `cp.ndarray` for CuPy.

Each namespace (`array_api_compat.numpy` and `array_api_compat.cupy`) is
populated with the normal library namespace (like `from numpy import *`). Then
specific functions are replaced with wrapped variants. Wrapped functions that
have the same logic between NumPy and CuPy (which is most functions) are in
`array_api_compat/common/`. These functions are defined like

```py
# In array_api_compat/common/_aliases.py

def acos(x, /, xp):
return xp.arccos(x)
```

The `xp` argument refers to the original array namespace (either `numpy` or
`cupy`). Then in the specific `array_api_compat/numpy` and
`array_api_compat/cupy` namespace, the `get_xp` decorator is applied to these
functions, which automatically removes the `xp` argument from the function
signature and replaces it with the corresponding array library, like

```py
# In array_api_compat/numpy/_aliases.py

from ..common import _aliases

import numpy as np

acos = get_xp(np)(_aliases.acos)
```

This `acos` now has the signature `acos(x, /)` and calls `numpy.arccos`.

Similarly, for CuPy:

```py
# In array_api_compat/cupy/_aliases.py

from ..common import _aliases

import cupy as cp

acos = get_xp(cp)(_aliases.acos)
```

Since NumPy and CuPy are nearly identical in their behaviors, this allows
writing the wrapping logic for both libraries only once. If support is added
for other libraries which differ significantly from NumPy, their wrapper code
should go in their specific sub-namespace instead of `common/`.
20 changes: 20 additions & 0 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
NumPy Array API compatibility library

This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.

Unlike numpy.array_api, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike numpy.array_api, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.

Library authors using the Array API may wish to test against numpy.array_api
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy arrays.

"""
from .common import *
43 changes: 43 additions & 0 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Internal helpers
"""

from functools import wraps
from inspect import signature

def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.

Use like

import numpy as np

@get_xp(np)
def func(x, /, xp, kwarg=None):
return xp.func(x, kwarg=kwarg)

Note that xp must be a keyword argument and come after all non-keyword
arguments.

"""
def inner(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
return f(*args, xp=xp, **kwargs)

sig = signature(f)
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])

if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.

See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.

"""
wrapped_f.__signature__ = new_sig
return wrapped_f

return inner
1 change: 1 addition & 0 deletions array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._helpers import *
Loading