From 9f2afe84eb1e88f01c53bbd8027a12e19510a17c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Nov 2022 17:49:10 -0700 Subject: [PATCH 01/24] Add a get_xp decorator to support multiple array namespaces --- numpy_array_api_compat/_helpers.py | 7 +++++-- numpy_array_api_compat/_internal.py | 32 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 numpy_array_api_compat/_internal.py diff --git a/numpy_array_api_compat/_helpers.py b/numpy_array_api_compat/_helpers.py index 7b789bb6..b499bc91 100644 --- a/numpy_array_api_compat/_helpers.py +++ b/numpy_array_api_compat/_helpers.py @@ -19,7 +19,7 @@ def is_array_api_obj(x): """ return _is_numpy_array(x) or hasattr(x, '__array_namespace__') -def get_namespace(*xs): +def get_namespace(*xs, _use_compat=True): """ Get the array API compatible namespace for the arrays `xs`. @@ -30,7 +30,10 @@ def get_namespace(*xs): if hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__) elif _is_numpy_array(x): - namespaces.add(compat_namespace) + if _use_compat: + namespaces.add(compat_namespace) + else: + namespaces.add(np) else: # TODO: Support Python scalars? raise ValueError("The input is not a supported array type") diff --git a/numpy_array_api_compat/_internal.py b/numpy_array_api_compat/_internal.py new file mode 100644 index 00000000..eb78ded4 --- /dev/null +++ b/numpy_array_api_compat/_internal.py @@ -0,0 +1,32 @@ +""" +Internal helpers +""" + +from functools import wraps +from inspect import signature + +from ._helpers import get_namespace + +def get_xp(f): + """ + Decorator to automatically replace xp with the corresponding array module + + Use like + + @get_xp + def func(x, /, xp, kwarg=None): + return xp.func(x, kwarg=kwarg) + + Note that xp must be able to be passed as a keyword argument. + """ + @wraps(f) + def inner(*args, **kwargs): + xp = get_namespace(*args, _use_compat=False) + return f(*args, xp=xp, **kwargs) + + sig = signature(f) + new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + + inner.__signature__ = new_sig + + return inner From e9c52c46934ebf582e938bd5d53ccba255393233 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 16 Nov 2022 22:55:12 -0700 Subject: [PATCH 02/24] Add get_xp decorator to genericize the namespace for the aliases I still need to factor out uses of 'np' in other places in the code. I'm also not completely sure about the solution for asarray(). asarray() can't just infer the namespace from the input like every other function can. For now, I have added a namespace keyword argument to it, which defaults to numpy. --- numpy_array_api_compat/_aliases.py | 248 +++++++++++++++++++--------- numpy_array_api_compat/_helpers.py | 4 +- numpy_array_api_compat/_internal.py | 8 + 3 files changed, 183 insertions(+), 77 deletions(-) diff --git a/numpy_array_api_compat/_aliases.py b/numpy_array_api_compat/_aliases.py index 0223c068..0e0810c4 100644 --- a/numpy_array_api_compat/_aliases.py +++ b/numpy_array_api_compat/_aliases.py @@ -6,27 +6,67 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Tuple, Union + from typing import Optional, Tuple, Union, List from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol from typing import NamedTuple +from types import ModuleType -import numpy as np +from ._helpers import _is_numpy_array +from ._internal import get_xp # Basic renames -acos = np.arccos -acosh = np.arccosh -asin = np.arcsin -asinh = np.arcsinh -atan = np.arctan -atan2 = np.arctan2 -atanh = np.arctanh -bitwise_left_shift = np.left_shift -bitwise_invert = np.invert -bitwise_right_shift = np.right_shift -bool = np.bool_ -concat = np.concatenate -pow = np.power +@get_xp +def acos(x, /, xp): + return xp.arccos(x) + +@get_xp +def acosh(x, /, xp): + return xp.arccosh(x) + +@get_xp +def asin(x, /, xp): + return xp.arcsin(x) + +@get_xp +def asinh(x, /, xp): + return xp.arcsinh(x) + +@get_xp +def atan(x, /, xp): + return xp.arctan(x) + +@get_xp +def atan2(x1, x2, /, xp): + return xp.arctan2(x1, x2) + +@get_xp +def atanh(x, /, xp): + return xp.arctanh(x) + +@get_xp +def bitwise_left_shift(x1, x2, /, xp): + return xp.left_shift(x1, x2) + +@get_xp +def bitwise_invert(x, /, xp): + return xp.invert(x) + +@get_xp +def bitwise_right_shift(x1, x2, /, xp): + return xp.right_shift(x1, x2) + +@get_xp +def bool(x, /, xp): + return xp.bool_(x) + +@get_xp +def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray: + return xp.concatenate(arrays, axis=axis) + +@get_xp +def pow(x1, x2, /, xp): + return xp.power(x1, x2) # These functions are modified from the NumPy versions. @@ -53,8 +93,9 @@ class UniqueInverseResult(NamedTuple): inverse_indices: ndarray -def unique_all(x: ndarray, /) -> UniqueAllResult: - values, indices, inverse_indices, counts = np.unique( +@get_xp +def unique_all(x: ndarray, /, xp) -> UniqueAllResult: + values, indices, inverse_indices, counts = xp.unique( x, return_counts=True, return_index=True, @@ -72,8 +113,9 @@ def unique_all(x: ndarray, /) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /) -> UniqueCountsResult: - res = np.unique( +@get_xp +def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: + res = xp.unique( x, return_counts=True, return_index=False, @@ -84,22 +126,24 @@ def unique_counts(x: ndarray, /) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /) -> UniqueInverseResult: - values, inverse_indices = np.unique( +@get_xp +def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: + values, inverse_indices = xp.unique( x, return_counts=False, return_index=False, return_inverse=True, equal_nan=False, ) - # np.unique() flattens inverse indices, but they need to share x's shape + # xp.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /) -> ndarray: - return np.unique( +@get_xp +def unique_values(x: ndarray, /, xp) -> ndarray: + return xp.unique( x, return_counts=False, return_index=False, @@ -114,29 +158,34 @@ def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: # These functions have different keyword argument names +@get_xp def std( x: ndarray, /, + xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, ) -> ndarray: - return np.std(x, axis=axis, ddof=correction, keepdims=keepdims) + return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims) +@get_xp def var( x: ndarray, /, + xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, ) -> ndarray: - return np.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray: - return np.transpose(x, axes) +@get_xp +def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: + return xp.transpose(x, axes) # Creation functions add the device keyword (which does nothing for NumPy) @@ -158,24 +207,44 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, + copy: "Optional[Union[bool, np._CopyMode]]" = None, + namespace = 'numpy', ) -> ndarray: + + if isinstance(namespace, ModuleType): + xp = namespace + elif namespace == 'numpy': + import numpy as xp + elif namespace == 'cupy': + import cupy as xp + else: + raise ValueError("Unrecognized namespace argument to asarray()") + _check_device(device) - if copy in (False, np._CopyMode.IF_NEEDED): - # copy=False is not yet implemented in np.asarray + if _is_numpy_array(obj): + import numpy as np + COPY_FALSE = (False, np._CopyMode.IF_NEEDED) + COPY_TRUE = (True, np._CopyMode.ALWAYS) + else: + COPY_FALSE = (False,) + COPY_TRUE = (True,) + if copy in COPY_FALSE: + # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, np.ndarray): + if isinstance(obj, xp.ndarray): if dtype is not None and obj.dtype != dtype: copy = True - if copy in (True, np._CopyMode.ALWAYS): - return np.array(obj, copy=True, dtype=dtype) + if copy in COPY_TRUE: + return xp.array(obj, copy=True, dtype=dtype) return obj - return np.asarray(obj, dtype=dtype) + return xp.asarray(obj, dtype=dtype) +@get_xp def arange( start: Union[int, float], /, + xp, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, @@ -183,60 +252,71 @@ def arange( device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.arange(start, stop=stop, step=step, dtype=dtype) + return xp.arange(start, stop=stop, step=step, dtype=dtype) +@get_xp def empty( shape: Union[int, Tuple[int, ...]], + xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.empty(shape, dtype=dtype) + return xp.empty(shape, dtype=dtype) +@get_xp def empty_like( - x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: _check_device(device) - return np.empty_like(x, dtype=dtype) + return xp.empty_like(x, dtype=dtype) +@get_xp def eye( n_rows: int, n_cols: Optional[int] = None, /, *, + xp, k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.eye(n_rows, M=n_cols, k=k, dtype=dtype) + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) +@get_xp def full( shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], + xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.full(shape, fill_value, dtype=dtype) + return xp.full(shape, fill_value, dtype=dtype) +@get_xp def full_like( x: ndarray, /, + xp, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.full_like(x, fill_value, dtype=dtype) + return xp.full_like(x, fill_value, dtype=dtype) +@get_xp def linspace( start: Union[int, float], stop: Union[int, float], /, + xp, num: int, *, dtype: Optional[Dtype] = None, @@ -244,62 +324,70 @@ def linspace( endpoint: bool = True, ) -> ndarray: _check_device(device) - return np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) +@get_xp def ones( shape: Union[int, Tuple[int, ...]], + xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.ones(shape, dtype=dtype) + return xp.ones(shape, dtype=dtype) +@get_xp def ones_like( - x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: _check_device(device) - return np.ones_like(x, dtype=dtype) + return xp.ones_like(x, dtype=dtype) +@get_xp def zeros( shape: Union[int, Tuple[int, ...]], + xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: _check_device(device) - return np.zeros(shape, dtype=dtype) + return xp.zeros(shape, dtype=dtype) +@get_xp def zeros_like( - x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: _check_device(device) - return np.zeros_like(x, dtype=dtype) + return xp.zeros_like(x, dtype=dtype) -# np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, /, shape: Tuple[int, ...], copy: Optional[bool] = None) -> ndarray: +# xp.reshape calls the keyword argument 'newshape' instead of 'shape' +@get_xp +def reshape(x: ndarray, /, xp, shape: Tuple[int, ...], copy: Optional[bool] = None) -> ndarray: if copy is True: x = x.copy() elif copy is False: x.shape = shape return x - return np.reshape(x, shape) + return xp.reshape(x, shape) # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' +@get_xp def argsort( - x: ndarray, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> ndarray: # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" if not descending: - res = np.argsort(x, axis=axis, kind=kind) + res = xp.argsort(x, axis=axis, kind=kind) else: # As NumPy has no native descending sort, we imitate it here. Note that - # simply flipping the results of np.argsort(x, ...) would not + # simply flipping the results of xp.argsort(x, ...) would not # respect the relative order like it would in native descending sorts. - res = np.flip( - np.argsort(np.flip(x, axis=axis), axis=axis, kind=kind), + res = xp.flip( + xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind), axis=axis, ) # Rely on flip()/argsort() to validate axis @@ -308,58 +396,66 @@ def argsort( res = max_i - res return res +@get_xp def sort( - x: ndarray, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> ndarray: # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = np.sort(x, axis=axis, kind=kind) + res = xp.sort(x, axis=axis, kind=kind) if descending: - res = np.flip(res, axis=axis) + res = xp.flip(res, axis=axis) return res # sum() and prod() should always upcast when dtype=None +@get_xp def sum( x: ndarray, /, + xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> ndarray: - # `np.sum` already upcasts integers, but not floats - if dtype is None and x.dtype == np.float32: - dtype = np.float64 - return np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + # `xp.sum` already upcasts integers, but not floats + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) +@get_xp def prod( x: ndarray, /, + xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> ndarray: - if dtype is None and x.dtype == np.float32: - dtype = np.float64 - return np.prod(x, dtype=dtype, axis=axis, keepdims=keepdims) + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /) -> ndarray: - if np.issubdtype(x.dtype, np.integer): +@get_xp +def ceil(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): return x - return np.ceil(x) + return xp.ceil(x) -def floor(x: ndarray, /) -> ndarray: - if np.issubdtype(x.dtype, np.integer): +@get_xp +def floor(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): return x - return np.floor(x) + return xp.floor(x) -def trunc(x: ndarray, /) -> ndarray: - if np.issubdtype(x.dtype, np.integer): +@get_xp +def trunc(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): return x - return np.trunc(x) + return xp.trunc(x) # from numpy import * doesn't overwrite these builtin names from numpy import abs, max, min, round diff --git a/numpy_array_api_compat/_helpers.py b/numpy_array_api_compat/_helpers.py index b499bc91..ae3edadd 100644 --- a/numpy_array_api_compat/_helpers.py +++ b/numpy_array_api_compat/_helpers.py @@ -27,7 +27,9 @@ def get_namespace(*xs, _use_compat=True): """ namespaces = set() for x in xs: - if hasattr(x, '__array_namespace__'): + if isinstance(x, (tuple, list)): + namespaces.add(get_namespace(*x, _use_compat=_use_compat)) + elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__) elif _is_numpy_array(x): if _use_compat: diff --git a/numpy_array_api_compat/_internal.py b/numpy_array_api_compat/_internal.py index eb78ded4..6a63703a 100644 --- a/numpy_array_api_compat/_internal.py +++ b/numpy_array_api_compat/_internal.py @@ -27,6 +27,14 @@ def inner(*args, **kwargs): sig = signature(f) new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + if inner.__doc__ is None: + inner.__doc__ = f"""\ +Array API compatibility wrapper for {f.__name__}. + +See the corresponding documentation in NumPy/CuPy and/or the array API +specification for more details. + +""" inner.__signature__ = new_sig return inner From 17513568dc0199364a33c46acaacea0ab1d1cdb0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 16 Nov 2022 22:58:35 -0700 Subject: [PATCH 03/24] Require the namespace argument to asarray() when it's ambiguous --- numpy_array_api_compat/_aliases.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/numpy_array_api_compat/_aliases.py b/numpy_array_api_compat/_aliases.py index 0e0810c4..5c3b9fe6 100644 --- a/numpy_array_api_compat/_aliases.py +++ b/numpy_array_api_compat/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple from types import ModuleType -from ._helpers import _is_numpy_array +from ._helpers import _is_numpy_array, get_namespace from ._internal import get_xp # Basic renames @@ -208,10 +208,15 @@ def asarray( dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, - namespace = 'numpy', + namespace = None, ) -> ndarray: - - if isinstance(namespace, ModuleType): + if namespace is None: + try: + xp = get_namespace(obj, _use_compat=False) + except ValueError: + # TODO: What about lists of arrays? + raise ValueError("A namespace must be specified for asarray() with non-array input") + elif isinstance(namespace, ModuleType): xp = namespace elif namespace == 'numpy': import numpy as xp From b3a12d92e840be6ef0f87f568d68a4f58b55c7e7 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 17 Nov 2022 22:12:40 -0700 Subject: [PATCH 04/24] Move all the NumPy functionality into a numpy submodule There will also be a 'cupy' submodule. Common code will be factored out into the 'common' submodule. --- numpy_array_api_compat/__init__.py | 18 --------- numpy_array_api_compat/_internal.py | 2 +- numpy_array_api_compat/common/__init__.py | 0 numpy_array_api_compat/common/_helpers.py | 37 +++++++++++++++++++ numpy_array_api_compat/numpy/__init__.py | 20 ++++++++++ .../{ => numpy}/_aliases.py | 20 +++++----- .../{ => numpy}/_helpers.py | 36 +----------------- numpy_array_api_compat/{ => numpy}/_typing.py | 0 numpy_array_api_compat/{ => numpy}/linalg.py | 0 9 files changed, 70 insertions(+), 63 deletions(-) create mode 100644 numpy_array_api_compat/common/__init__.py create mode 100644 numpy_array_api_compat/common/_helpers.py create mode 100644 numpy_array_api_compat/numpy/__init__.py rename numpy_array_api_compat/{ => numpy}/_aliases.py (96%) rename numpy_array_api_compat/{ => numpy}/_helpers.py (73%) rename numpy_array_api_compat/{ => numpy}/_typing.py (100%) rename numpy_array_api_compat/{ => numpy}/linalg.py (100%) diff --git a/numpy_array_api_compat/__init__.py b/numpy_array_api_compat/__init__.py index 82f36e98..babe1d72 100644 --- a/numpy_array_api_compat/__init__.py +++ b/numpy_array_api_compat/__init__.py @@ -40,21 +40,3 @@ - NumPy functions which are not wrapped may not use positional-only arguments. """ - -from numpy import * - -# These imports may overwrite names from the import * above. -from ._aliases import * - -# Don't know why, but we have to do an absolute import to import linalg. If we -# instead do -# -# from . import linalg -# -# It doesn't overwrite np.linalg from above. The import is generated -# dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') - -from .linalg import matrix_transpose, vecdot - -from ._helpers import * diff --git a/numpy_array_api_compat/_internal.py b/numpy_array_api_compat/_internal.py index 6a63703a..0448cd53 100644 --- a/numpy_array_api_compat/_internal.py +++ b/numpy_array_api_compat/_internal.py @@ -5,7 +5,7 @@ from functools import wraps from inspect import signature -from ._helpers import get_namespace +from .common._helpers import get_namespace def get_xp(f): """ diff --git a/numpy_array_api_compat/common/__init__.py b/numpy_array_api_compat/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/numpy_array_api_compat/common/_helpers.py b/numpy_array_api_compat/common/_helpers.py new file mode 100644 index 00000000..5090fd5d --- /dev/null +++ b/numpy_array_api_compat/common/_helpers.py @@ -0,0 +1,37 @@ +""" +Various helper functions which are not part of the spec. +""" +def get_namespace(*xs, _use_compat=True): + """ + Get the array API compatible namespace for the arrays `xs`. + + `xs` should contain one or more arrays. + """ + from ..numpy._helpers import _is_numpy_array + + namespaces = set() + for x in xs: + if isinstance(x, (tuple, list)): + namespaces.add(get_namespace(*x, _use_compat=_use_compat)) + elif hasattr(x, '__array_namespace__'): + namespaces.add(x.__array_namespace__) + elif _is_numpy_array(x): + if _use_compat: + from .. import numpy as numpy_namespace + namespaces.add(numpy_namespace) + else: + import numpy as np + namespaces.add(np) + else: + # TODO: Support Python scalars? + raise ValueError("The input is not a supported array type") + + if not namespaces: + raise ValueError("Unrecognized array input") + + if len(namespaces) != 1: + raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") + + xp, = namespaces + + return xp diff --git a/numpy_array_api_compat/numpy/__init__.py b/numpy_array_api_compat/numpy/__init__.py new file mode 100644 index 00000000..46c000f9 --- /dev/null +++ b/numpy_array_api_compat/numpy/__init__.py @@ -0,0 +1,20 @@ +from numpy import * + +# from numpy import * doesn't overwrite these builtin names +from numpy import abs, max, min, round + +# These imports may overwrite names from the import * above. +from ._aliases import * + +# Don't know why, but we have to do an absolute import to import linalg. If we +# instead do +# +# from . import linalg +# +# It doesn't overwrite np.linalg from above. The import is generated +# dynamically so that the library can be vendored. +__import__(__package__ + '.linalg') + +from .linalg import matrix_transpose, vecdot + +from ._helpers import * diff --git a/numpy_array_api_compat/_aliases.py b/numpy_array_api_compat/numpy/_aliases.py similarity index 96% rename from numpy_array_api_compat/_aliases.py rename to numpy_array_api_compat/numpy/_aliases.py index 5c3b9fe6..aba23aad 100644 --- a/numpy_array_api_compat/_aliases.py +++ b/numpy_array_api_compat/numpy/_aliases.py @@ -9,11 +9,12 @@ from typing import Optional, Tuple, Union, List from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from functools import partial from typing import NamedTuple from types import ModuleType from ._helpers import _is_numpy_array, get_namespace -from ._internal import get_xp +from .._internal import get_xp # Basic renames @get_xp @@ -194,7 +195,7 @@ def _check_device(device): raise ValueError(f"Unsupported device {device!r}") # asarray also adds the copy keyword -def asarray( +def _asarray( obj: Union[ ndarray, bool, @@ -245,6 +246,8 @@ def asarray( return xp.asarray(obj, dtype=dtype) +asarray_numpy = partial(_asarray, namespace='numpy') + @get_xp def arange( start: Union[int, float], @@ -462,15 +465,12 @@ def trunc(x: ndarray, /, xp) -> ndarray: return x return xp.trunc(x) -# from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round - __all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'astype', 'abs', 'max', 'min', - 'round', 'std', 'var', 'permute_dims', 'asarray', 'arange', - 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', - 'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape', 'argsort', - 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc'] + 'unique_inverse', 'unique_values', 'astype', 'std', 'var', + 'permute_dims', 'asarray_numpy', 'arange', 'empty', 'empty_like', + 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', + 'zeros', 'zeros_like', 'reshape', 'argsort', 'sort', 'sum', 'prod', + 'ceil', 'floor', 'trunc'] diff --git a/numpy_array_api_compat/_helpers.py b/numpy_array_api_compat/numpy/_helpers.py similarity index 73% rename from numpy_array_api_compat/_helpers.py rename to numpy_array_api_compat/numpy/_helpers.py index ae3edadd..b658f5b6 100644 --- a/numpy_array_api_compat/_helpers.py +++ b/numpy_array_api_compat/numpy/_helpers.py @@ -4,11 +4,10 @@ from __future__ import annotations -import importlib -compat_namespace = importlib.import_module(__package__) - import numpy as np +from ..common._helpers import get_namespace + def _is_numpy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (np.ndarray, np.generic)) @@ -19,37 +18,6 @@ def is_array_api_obj(x): """ return _is_numpy_array(x) or hasattr(x, '__array_namespace__') -def get_namespace(*xs, _use_compat=True): - """ - Get the array API compatible namespace for the arrays `xs`. - - `xs` should contain one or more arrays. - """ - namespaces = set() - for x in xs: - if isinstance(x, (tuple, list)): - namespaces.add(get_namespace(*x, _use_compat=_use_compat)) - elif hasattr(x, '__array_namespace__'): - namespaces.add(x.__array_namespace__) - elif _is_numpy_array(x): - if _use_compat: - namespaces.add(compat_namespace) - else: - namespaces.add(np) - else: - # TODO: Support Python scalars? - raise ValueError("The input is not a supported array type") - - if not namespaces: - raise ValueError("Unrecognized array input") - - if len(namespaces) != 1: - raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") - - xp, = namespaces - - return xp - # device and to_device are not included in array object of this library # because this library just reuses ndarray without wrapping or subclassing it. # These helper functions can be used instead of the wrapper functions for diff --git a/numpy_array_api_compat/_typing.py b/numpy_array_api_compat/numpy/_typing.py similarity index 100% rename from numpy_array_api_compat/_typing.py rename to numpy_array_api_compat/numpy/_typing.py diff --git a/numpy_array_api_compat/linalg.py b/numpy_array_api_compat/numpy/linalg.py similarity index 100% rename from numpy_array_api_compat/linalg.py rename to numpy_array_api_compat/numpy/linalg.py From cf4083a9ce365f6247c84ec63eac1f63252105b2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 17:21:42 -0700 Subject: [PATCH 05/24] Move the wrapper code into common/, and make linalg use @get_xp --- numpy_array_api_compat/common/_aliases.py | 480 ++++++++++++++++++++++ numpy_array_api_compat/common/linalg.py | 182 ++++++++ numpy_array_api_compat/numpy/_aliases.py | 479 +-------------------- numpy_array_api_compat/numpy/_helpers.py | 8 +- numpy_array_api_compat/numpy/linalg.py | 156 +------ 5 files changed, 679 insertions(+), 626 deletions(-) create mode 100644 numpy_array_api_compat/common/_aliases.py create mode 100644 numpy_array_api_compat/common/linalg.py diff --git a/numpy_array_api_compat/common/_aliases.py b/numpy_array_api_compat/common/_aliases.py new file mode 100644 index 00000000..50cb62cf --- /dev/null +++ b/numpy_array_api_compat/common/_aliases.py @@ -0,0 +1,480 @@ +""" +These are functions that are just aliases of existing functions in NumPy. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Tuple, Union, List + from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol + +from typing import NamedTuple +from types import ModuleType + +from ..numpy._helpers import _is_numpy_array, get_namespace +from .._internal import get_xp + +# Basic renames +@get_xp +def acos(x, /, xp): + return xp.arccos(x) + +@get_xp +def acosh(x, /, xp): + return xp.arccosh(x) + +@get_xp +def asin(x, /, xp): + return xp.arcsin(x) + +@get_xp +def asinh(x, /, xp): + return xp.arcsinh(x) + +@get_xp +def atan(x, /, xp): + return xp.arctan(x) + +@get_xp +def atan2(x1, x2, /, xp): + return xp.arctan2(x1, x2) + +@get_xp +def atanh(x, /, xp): + return xp.arctanh(x) + +@get_xp +def bitwise_left_shift(x1, x2, /, xp): + return xp.left_shift(x1, x2) + +@get_xp +def bitwise_invert(x, /, xp): + return xp.invert(x) + +@get_xp +def bitwise_right_shift(x1, x2, /, xp): + return xp.right_shift(x1, x2) + +@get_xp +def bool(x, /, xp): + return xp.bool_(x) + +@get_xp +def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray: + return xp.concatenate(arrays, axis=axis) + +@get_xp +def pow(x1, x2, /, xp): + return xp.power(x1, x2) + +# These functions are modified from the NumPy versions. + +# np.unique() is split into four functions in the array API: +# unique_all, unique_counts, unique_inverse, and unique_values (this is done +# to remove polymorphic return types). + +# The functions here return namedtuples (np.unique() returns a normal +# tuple). +class UniqueAllResult(NamedTuple): + values: ndarray + indices: ndarray + inverse_indices: ndarray + counts: ndarray + + +class UniqueCountsResult(NamedTuple): + values: ndarray + counts: ndarray + + +class UniqueInverseResult(NamedTuple): + values: ndarray + inverse_indices: ndarray + + +@get_xp +def unique_all(x: ndarray, /, xp) -> UniqueAllResult: + values, indices, inverse_indices, counts = xp.unique( + x, + return_counts=True, + return_index=True, + return_inverse=True, + equal_nan=False, + ) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult( + values, + indices, + inverse_indices, + counts, + ) + + +@get_xp +def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: + res = xp.unique( + x, + return_counts=True, + return_index=False, + return_inverse=False, + equal_nan=False, + ) + + return UniqueCountsResult(*res) + + +@get_xp +def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: + values, inverse_indices = xp.unique( + x, + return_counts=False, + return_index=False, + return_inverse=True, + equal_nan=False, + ) + # xp.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(values, inverse_indices) + + +@get_xp +def unique_values(x: ndarray, /, xp) -> ndarray: + return xp.unique( + x, + return_counts=False, + return_index=False, + return_inverse=False, + equal_nan=False, + ) + +def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: + if not copy and dtype == x.dtype: + return x + return x.astype(dtype=dtype, copy=copy) + +# These functions have different keyword argument names + +@get_xp +def std( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, # correction instead of ddof + keepdims: bool = False, +) -> ndarray: + return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims) + +@get_xp +def var( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, # correction instead of ddof + keepdims: bool = False, +) -> ndarray: + return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) + +# Unlike transpose(), the axes argument to permute_dims() is required. +@get_xp +def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: + return xp.transpose(x, axes) + +# Creation functions add the device keyword (which does nothing for NumPy) + +def _check_device(device): + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device {device!r}") + +# asarray also adds the copy keyword +def _asarray( + obj: Union[ + ndarray, + bool, + int, + float, + NestedSequence[bool | int | float], + SupportsBufferProtocol, + ], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: "Optional[Union[bool, np._CopyMode]]" = None, + namespace = None, +) -> ndarray: + """ + Array API compatibility wrapper for asarray(). + + See the corresponding documentation in NumPy/CuPy and/or the array API + specification for more details. + + """ + if namespace is None: + try: + xp = get_namespace(obj, _use_compat=False) + except ValueError: + # TODO: What about lists of arrays? + raise ValueError("A namespace must be specified for asarray() with non-array input") + elif isinstance(namespace, ModuleType): + xp = namespace + elif namespace == 'numpy': + import numpy as xp + elif namespace == 'cupy': + import cupy as xp + else: + raise ValueError("Unrecognized namespace argument to asarray()") + + _check_device(device) + if _is_numpy_array(obj): + import numpy as np + COPY_FALSE = (False, np._CopyMode.IF_NEEDED) + COPY_TRUE = (True, np._CopyMode.ALWAYS) + else: + COPY_FALSE = (False,) + COPY_TRUE = (True,) + if copy in COPY_FALSE: + # copy=False is not yet implemented in xp.asarray + raise NotImplementedError("copy=False is not yet implemented") + if isinstance(obj, xp.ndarray): + if dtype is not None and obj.dtype != dtype: + copy = True + if copy in COPY_TRUE: + return xp.array(obj, copy=True, dtype=dtype) + return obj + + return xp.asarray(obj, dtype=dtype) + +@get_xp +def arange( + start: Union[int, float], + /, + xp, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.arange(start, stop=stop, step=step, dtype=dtype) + +@get_xp +def empty( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.empty(shape, dtype=dtype) + +@get_xp +def empty_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(device) + return xp.empty_like(x, dtype=dtype) + +@get_xp +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + xp, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + +@get_xp +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.full(shape, fill_value, dtype=dtype) + +@get_xp +def full_like( + x: ndarray, + /, + xp, + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.full_like(x, fill_value, dtype=dtype) + +@get_xp +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + xp, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> ndarray: + _check_device(device) + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + +@get_xp +def ones( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.ones(shape, dtype=dtype) + +@get_xp +def ones_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(device) + return xp.ones_like(x, dtype=dtype) + +@get_xp +def zeros( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(device) + return xp.zeros(shape, dtype=dtype) + +@get_xp +def zeros_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(device) + return xp.zeros_like(x, dtype=dtype) + +# xp.reshape calls the keyword argument 'newshape' instead of 'shape' +@get_xp +def reshape(x: ndarray, /, xp, shape: Tuple[int, ...], copy: Optional[bool] = None) -> ndarray: + if copy is True: + x = x.copy() + elif copy is False: + x.shape = shape + return x + return xp.reshape(x, shape) + +# The descending keyword is new in sort and argsort, and 'kind' replaced with +# 'stable' +@get_xp +def argsort( + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> ndarray: + # Note: this keyword argument is different, and the default is different. + kind = "stable" if stable else "quicksort" + if not descending: + res = xp.argsort(x, axis=axis, kind=kind) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of xp.argsort(x, ...) would not + # respect the relative order like it would in native descending sorts. + res = xp.flip( + xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res + return res + +@get_xp +def sort( + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> ndarray: + # Note: this keyword argument is different, and the default is different. + kind = "stable" if stable else "quicksort" + res = xp.sort(x, axis=axis, kind=kind) + if descending: + res = xp.flip(res, axis=axis) + return res + +# sum() and prod() should always upcast when dtype=None +@get_xp +def sum( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> ndarray: + # `xp.sum` already upcasts integers, but not floats + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + +@get_xp +def prod( + x: ndarray, + /, + xp, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> ndarray: + if dtype is None and x.dtype == xp.float32: + dtype = xp.float64 + return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims) + +# ceil, floor, and trunc return integers for integer inputs + +@get_xp +def ceil(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.ceil(x) + +@get_xp +def floor(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.floor(x) + +@get_xp +def trunc(x: ndarray, /, xp) -> ndarray: + if xp.issubdtype(x.dtype, xp.integer): + return x + return xp.trunc(x) + +__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', + 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', + 'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', + 'UniqueInverseResult', 'unique_all', 'unique_counts', + 'unique_inverse', 'unique_values', 'astype', 'std', 'var', + 'permute_dims', 'arange', 'empty', 'empty_like', 'eye', 'full', + 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', + 'zeros_like', 'reshape', 'argsort', 'sort', 'sum', 'prod', 'ceil', + 'floor', 'trunc'] diff --git a/numpy_array_api_compat/common/linalg.py b/numpy_array_api_compat/common/linalg.py new file mode 100644 index 00000000..e2a2c478 --- /dev/null +++ b/numpy_array_api_compat/common/linalg.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple +if TYPE_CHECKING: + from typing import Literal, Optional, Sequence, Tuple, Union + from ._typing import ndarray + +from numpy.core.numeric import normalize_axis_tuple + +from .._internal import get_xp + +# These are in the main NumPy namespace but not in numpy.linalg +@get_xp +def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + return xp.cross(x1, x2, axis=axis) + +@get_xp +def matmul(x1: ndarray, x2: ndarray, /, xp) -> ndarray: + return xp.matmul(x1, x2) + +@get_xp +def outer(x1: ndarray, x2: ndarray, /, xp) -> ndarray: + return xp.outer(x1, x2) + +@get_xp +def tensordot(x1: ndarray, x2: ndarray, /, xp, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> ndarray: + return xp.tensordot(x1, x2, axes=axes) + +class EighResult(NamedTuple): + eigenvalues: ndarray + eigenvectors: ndarray + +class QRResult(NamedTuple): + Q: ndarray + R: ndarray + +class SlogdetResult(NamedTuple): + sign: ndarray + logabsdet: ndarray + +class SVDResult(NamedTuple): + U: ndarray + S: ndarray + Vh: ndarray + +# These functions are the same as their NumPy counterparts except they return +# a namedtuple. +@get_xp +def eigh(x: ndarray, /, xp) -> EighResult: + return EighResult(*xp.linalg.eigh(x)) + +@get_xp +def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: + return QRResult(*xp.linalg.qr(x, mode=mode)) + +@get_xp +def slogdet(x: ndarray, /, xp) -> SlogdetResult: + return SlogdetResult(*xp.linalg.slogdet(x)) + +@get_xp +def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult: + return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices)) + +# These functions have additional keyword arguments + +# The upper keyword argument is new from NumPy +@get_xp +def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray: + L = xp.linalg.cholesky(x) + if upper: + return matrix_transpose(L) + return L + +# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. +# Note that it has a different semantic meaning from tol and rcond. +@get_xp +def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: + # this is different from xp.linalg.matrix_rank, which supports 1 + # dimensional arrays. + if x.ndim < 2: + raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") + S = xp.linalg.svd(x, compute_uv=False) + if rtol is None: + tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps + else: + # this is different from xp.linalg.matrix_rank, which does not + # multiply the tolerance by the largest singular value. + tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] + return xp.count_nonzero(S > tol, axis=-1) + +@get_xp +def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: + # this is different from xp.linalg.pinv, which does not multiply the + # default tolerance by max(M, N). + if rtol is None: + rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps + return xp.linalg.pinv(x, rcond=rtol) + +# These functions are new in the array API spec + +@get_xp +def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: + return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) + +# Unlike transpose, matrix_transpose only transposes the last two axes. +@get_xp +def matrix_transpose(x: ndarray, /, xp) -> ndarray: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return xp.swapaxes(x, -1, -2) + +# svdvals is not in NumPy (but it is in SciPy). It is equivalent to +# xp.linalg.svd(compute_uv=False). +@get_xp +def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: + return xp.linalg.svd(x, compute_uv=False) + +@get_xp +def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = xp.broadcast_arrays(x1, x2) + x1_ = xp.moveaxis(x1_, axis, -1) + x2_ = xp.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + +@get_xp +def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: + # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or + # when axis=None and the input is 2-D, so to force a vector norm, we make + # it so the input is 1-D (for axis=None), or reshape so that norm is done + # on a single dimension. + if axis is None: + # Note: xp.linalg.norm() doesn't handle 0-D arrays + x = x.ravel() + _axis = 0 + elif isinstance(axis, tuple): + # Note: The axis argument supports any number of axes, whereas + # xp.linalg.norm() only supports a single axis for vector norm. + normalized_axis = normalize_axis_tuple(axis, x.ndim) + rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) + newshape = axis + rest + x = xp.transpose(x, newshape).reshape( + (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) + _axis = 0 + else: + _axis = axis + + res = xp.linalg.norm(x, axis=_axis, ord=ord) + + if keepdims: + # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks + # above to avoid matrix norm logic. + shape = list(x.shape) + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in _axis: + shape[i] = 1 + res = xp.reshape(res, tuple(shape)) + + return res + +# xp.diagonal and xp.trace operate on the first two axes whereas these +# operates on the last two + +@get_xp +def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: + return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1) + +@get_xp +def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: + return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1)) + +__all__ = ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv', + 'matrix_norm', 'matrix_transpose', 'outer', 'svdvals', + 'tensordot', 'trace', 'vecdot', 'vector_norm', 'EighResult', + 'QRResult', 'SlogdetResult', 'SVDResult'] diff --git a/numpy_array_api_compat/numpy/_aliases.py b/numpy_array_api_compat/numpy/_aliases.py index aba23aad..ae0172c7 100644 --- a/numpy_array_api_compat/numpy/_aliases.py +++ b/numpy_array_api_compat/numpy/_aliases.py @@ -1,476 +1,11 @@ -""" -These are functions that are just aliases of existing functions in NumPy. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Tuple, Union, List - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol - from functools import partial -from typing import NamedTuple -from types import ModuleType - -from ._helpers import _is_numpy_array, get_namespace -from .._internal import get_xp - -# Basic renames -@get_xp -def acos(x, /, xp): - return xp.arccos(x) - -@get_xp -def acosh(x, /, xp): - return xp.arccosh(x) - -@get_xp -def asin(x, /, xp): - return xp.arcsin(x) - -@get_xp -def asinh(x, /, xp): - return xp.arcsinh(x) - -@get_xp -def atan(x, /, xp): - return xp.arctan(x) - -@get_xp -def atan2(x1, x2, /, xp): - return xp.arctan2(x1, x2) - -@get_xp -def atanh(x, /, xp): - return xp.arctanh(x) - -@get_xp -def bitwise_left_shift(x1, x2, /, xp): - return xp.left_shift(x1, x2) - -@get_xp -def bitwise_invert(x, /, xp): - return xp.invert(x) - -@get_xp -def bitwise_right_shift(x1, x2, /, xp): - return xp.right_shift(x1, x2) - -@get_xp -def bool(x, /, xp): - return xp.bool_(x) - -@get_xp -def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray: - return xp.concatenate(arrays, axis=axis) - -@get_xp -def pow(x1, x2, /, xp): - return xp.power(x1, x2) - -# These functions are modified from the NumPy versions. - -# np.unique() is split into four functions in the array API: -# unique_all, unique_counts, unique_inverse, and unique_values (this is done -# to remove polymorphic return types). - -# The functions here return namedtuples (np.unique() returns a normal -# tuple). -class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray - - -class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray - - -class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray - - -@get_xp -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: - values, indices, inverse_indices, counts = xp.unique( - x, - return_counts=True, - return_index=True, - return_inverse=True, - equal_nan=False, - ) - # np.unique() flattens inverse indices, but they need to share x's shape - # See https://github.com/numpy/numpy/issues/20638 - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueAllResult( - values, - indices, - inverse_indices, - counts, - ) - - -@get_xp -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: - res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - equal_nan=False, - ) - - return UniqueCountsResult(*res) - - -@get_xp -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: - values, inverse_indices = xp.unique( - x, - return_counts=False, - return_index=False, - return_inverse=True, - equal_nan=False, - ) - # xp.unique() flattens inverse indices, but they need to share x's shape - # See https://github.com/numpy/numpy/issues/20638 - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueInverseResult(values, inverse_indices) - - -@get_xp -def unique_values(x: ndarray, /, xp) -> ndarray: - return xp.unique( - x, - return_counts=False, - return_index=False, - return_inverse=False, - equal_nan=False, - ) - -def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: - if not copy and dtype == x.dtype: - return x - return x.astype(dtype=dtype, copy=copy) - -# These functions have different keyword argument names - -@get_xp -def std( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof - keepdims: bool = False, -) -> ndarray: - return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims) - -@get_xp -def var( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof - keepdims: bool = False, -) -> ndarray: - return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) - -# Unlike transpose(), the axes argument to permute_dims() is required. -@get_xp -def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: - return xp.transpose(x, axes) - -# Creation functions add the device keyword (which does nothing for NumPy) - -def _check_device(device): - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device {device!r}") - -# asarray also adds the copy keyword -def _asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], - /, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, - namespace = None, -) -> ndarray: - if namespace is None: - try: - xp = get_namespace(obj, _use_compat=False) - except ValueError: - # TODO: What about lists of arrays? - raise ValueError("A namespace must be specified for asarray() with non-array input") - elif isinstance(namespace, ModuleType): - xp = namespace - elif namespace == 'numpy': - import numpy as xp - elif namespace == 'cupy': - import cupy as xp - else: - raise ValueError("Unrecognized namespace argument to asarray()") - - _check_device(device) - if _is_numpy_array(obj): - import numpy as np - COPY_FALSE = (False, np._CopyMode.IF_NEEDED) - COPY_TRUE = (True, np._CopyMode.ALWAYS) - else: - COPY_FALSE = (False,) - COPY_TRUE = (True,) - if copy in COPY_FALSE: - # copy=False is not yet implemented in xp.asarray - raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, xp.ndarray): - if dtype is not None and obj.dtype != dtype: - copy = True - if copy in COPY_TRUE: - return xp.array(obj, copy=True, dtype=dtype) - return obj - - return xp.asarray(obj, dtype=dtype) - -asarray_numpy = partial(_asarray, namespace='numpy') - -@get_xp -def arange( - start: Union[int, float], - /, - xp, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.arange(start, stop=stop, step=step, dtype=dtype) - -@get_xp -def empty( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.empty(shape, dtype=dtype) - -@get_xp -def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(device) - return xp.empty_like(x, dtype=dtype) - -@get_xp -def eye( - n_rows: int, - n_cols: Optional[int] = None, - /, - *, - xp, - k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) - -@get_xp -def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.full(shape, fill_value, dtype=dtype) - -@get_xp -def full_like( - x: ndarray, - /, - xp, - fill_value: Union[int, float], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.full_like(x, fill_value, dtype=dtype) - -@get_xp -def linspace( - start: Union[int, float], - stop: Union[int, float], - /, - xp, - num: int, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - endpoint: bool = True, -) -> ndarray: - _check_device(device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) - -@get_xp -def ones( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.ones(shape, dtype=dtype) - -@get_xp -def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(device) - return xp.ones_like(x, dtype=dtype) - -@get_xp -def zeros( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(device) - return xp.zeros(shape, dtype=dtype) - -@get_xp -def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(device) - return xp.zeros_like(x, dtype=dtype) - -# xp.reshape calls the keyword argument 'newshape' instead of 'shape' -@get_xp -def reshape(x: ndarray, /, xp, shape: Tuple[int, ...], copy: Optional[bool] = None) -> ndarray: - if copy is True: - x = x.copy() - elif copy is False: - x.shape = shape - return x - return xp.reshape(x, shape) - -# The descending keyword is new in sort and argsort, and 'kind' replaced with -# 'stable' -@get_xp -def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True -) -> ndarray: - # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" - if not descending: - res = xp.argsort(x, axis=axis, kind=kind) - else: - # As NumPy has no native descending sort, we imitate it here. Note that - # simply flipping the results of xp.argsort(x, ...) would not - # respect the relative order like it would in native descending sorts. - res = xp.flip( - xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind), - axis=axis, - ) - # Rely on flip()/argsort() to validate axis - normalised_axis = axis if axis >= 0 else x.ndim + axis - max_i = x.shape[normalised_axis] - 1 - res = max_i - res - return res - -@get_xp -def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True -) -> ndarray: - # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" - res = xp.sort(x, axis=axis, kind=kind) - if descending: - res = xp.flip(res, axis=axis) - return res - -# sum() and prod() should always upcast when dtype=None -@get_xp -def sum( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, -) -> ndarray: - # `xp.sum` already upcasts integers, but not floats - if dtype is None and x.dtype == xp.float32: - dtype = xp.float64 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) - -@get_xp -def prod( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, -) -> ndarray: - if dtype is None and x.dtype == xp.float32: - dtype = xp.float64 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims) - -# ceil, floor, and trunc return integers for integer inputs - -@get_xp -def ceil(x: ndarray, /, xp) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x) -@get_xp -def floor(x: ndarray, /, xp) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x) +from ..common._aliases import * +from ..common._aliases import _asarray +from ..common._aliases import __all__ -@get_xp -def trunc(x: ndarray, /, xp) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x) +asarray = asarray_numpy = partial(_asarray, namespace='numpy') +asarray.__doc__ = _asarray.__doc__ +del partial -__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', - 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', - 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'astype', 'std', 'var', - 'permute_dims', 'asarray_numpy', 'arange', 'empty', 'empty_like', - 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', - 'zeros', 'zeros_like', 'reshape', 'argsort', 'sort', 'sum', 'prod', - 'ceil', 'floor', 'trunc'] +__all__ = __all__ + ['asarray', 'asarray_numpy'] diff --git a/numpy_array_api_compat/numpy/_helpers.py b/numpy_array_api_compat/numpy/_helpers.py index b658f5b6..037e0f7a 100644 --- a/numpy_array_api_compat/numpy/_helpers.py +++ b/numpy_array_api_compat/numpy/_helpers.py @@ -4,11 +4,17 @@ from __future__ import annotations -import numpy as np +import sys from ..common._helpers import get_namespace def _is_numpy_array(x): + # Avoid importing NumPy if it isn't already + if 'numpy' not in sys.modules: + return False + + import numpy as np + # TODO: Should we reject ndarray subclasses? return isinstance(x, (np.ndarray, np.generic)) diff --git a/numpy_array_api_compat/numpy/linalg.py b/numpy_array_api_compat/numpy/linalg.py index ec3bc11d..1f6c95c0 100644 --- a/numpy_array_api_compat/numpy/linalg.py +++ b/numpy_array_api_compat/numpy/linalg.py @@ -1,157 +1,7 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from numpy import ndarray - -import numpy as np -from numpy.core.numeric import normalize_axis_tuple - from numpy.linalg import * from numpy.linalg import __all__ as linalg_all -# These are in the main NumPy namespace but not in numpy.linalg -from numpy import cross, matmul, outer, tensordot - -class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray - -class QRResult(NamedTuple): - Q: ndarray - R: ndarray - -class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray - -class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray - -# These functions are the same as their NumPy counterparts except they return -# a namedtuple. -def eigh(x: ndarray, /) -> EighResult: - return EighResult(*np.linalg.eigh(x)) - -def qr(x: ndarray, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: - return QRResult(*np.linalg.qr(x, mode=mode)) - -def slogdet(x: ndarray, /) -> SlogdetResult: - return SlogdetResult(*np.linalg.slogdet(x)) - -def svd(x: ndarray, /, *, full_matrices: bool = True) -> SVDResult: - return SVDResult(*np.linalg.svd(x, full_matrices=full_matrices)) - -# These functions have additional keyword arguments - -# The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, *, upper: bool = False) -> ndarray: - L = np.linalg.cholesky(x) - if upper: - return matrix_transpose(L) - return L - -# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. -# Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, /, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: - # this is different from np.linalg.matrix_rank, which supports 1 - # dimensional arrays. - if x.ndim < 2: - raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = np.linalg.svd(x, compute_uv=False) - if rtol is None: - tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps - else: - # this is different from np.linalg.matrix_rank, which does not - # multiply the tolerance by the largest singular value. - tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] - return np.count_nonzero(S > tol, axis=-1) - -def pinv(x: ndarray, /, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: - # this is different from np.linalg.pinv, which does not multiply the - # default tolerance by max(M, N). - if rtol is None: - rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps - return np.linalg.pinv(x, rcond=rtol) - -# These functions are new in the array API spec - -def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: - return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) - -# Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /) -> ndarray: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return np.swapaxes(x, -1, -2) - -# svdvals is not in NumPy (but it is in SciPy). It is equivalent to -# np.linalg.svd(compute_uv=False). -def svdvals(x: ndarray, /) -> Union[ndarray, Tuple[ndarray, ...]]: - return np.linalg.svd(x, compute_uv=False) - -def vecdot(x1: ndarray, x2: ndarray, /, *, axis: int = -1) -> ndarray: - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = np.broadcast_arrays(x1, x2) - x1_ = np.moveaxis(x1_, axis, -1) - x2_ = np.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - -def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: - # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or - # when axis=None and the input is 2-D, so to force a vector norm, we make - # it so the input is 1-D (for axis=None), or reshape so that norm is done - # on a single dimension. - if axis is None: - # Note: np.linalg.norm() doesn't handle 0-D arrays - x = x.ravel() - _axis = 0 - elif isinstance(axis, tuple): - # Note: The axis argument supports any number of axes, whereas - # np.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) - rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) - newshape = axis + rest - x = np.transpose(x, newshape).reshape( - (np.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) - _axis = 0 - else: - _axis = axis - - res = np.linalg.norm(x, axis=_axis, ord=ord) - - if keepdims: - # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks - # above to avoid matrix norm logic. - shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: - shape[i] = 1 - res = np.reshape(res, tuple(shape)) - - return res - -# np.diagonal and np.trace operate on the first two axes whereas these -# operates on the last two - -def diagonal(x: ndarray, /, *, offset: int = 0) -> ndarray: - return np.diagonal(x, offset=offset, axis1=-2, axis2=-1) - -def trace(x: ndarray, /, *, offset: int = 0) -> ndarray: - return np.asarray(np.trace(x, offset=offset, axis1=-2, axis2=-1)) +from ..common.linalg import * +from ..common.linalg import __all__ as common_linalg_all -__all__ = linalg_all.copy() -__all__ += ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv', - 'matrix_norm', 'matrix_transpose', 'outer', 'svdvals', - 'tensordot', 'trace', 'vecdot', 'vector_norm', 'EighResult', - 'QRResult', 'SlogdetResult', 'SVDResult'] +__all__ = linalg_all + common_linalg_all From 7333696f188a97236e485b2cef2ed26047b6d5eb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 17:22:58 -0700 Subject: [PATCH 06/24] Move _typing.py from numpy/ to common/ --- .../{numpy => common}/_typing.py | 0 numpy_array_api_compat/cupy/__init__.py | 20 +++++++++++++++++++ 2 files changed, 20 insertions(+) rename numpy_array_api_compat/{numpy => common}/_typing.py (100%) create mode 100644 numpy_array_api_compat/cupy/__init__.py diff --git a/numpy_array_api_compat/numpy/_typing.py b/numpy_array_api_compat/common/_typing.py similarity index 100% rename from numpy_array_api_compat/numpy/_typing.py rename to numpy_array_api_compat/common/_typing.py diff --git a/numpy_array_api_compat/cupy/__init__.py b/numpy_array_api_compat/cupy/__init__.py new file mode 100644 index 00000000..a3897fb7 --- /dev/null +++ b/numpy_array_api_compat/cupy/__init__.py @@ -0,0 +1,20 @@ +from cupy import * + +# from cupy import * doesn't overwrite these builtin names +from cupy import abs, max, min, round + +# These imports may overwrite names from the import * above. +from ._aliases import * + +# Don't know why, but we have to do an absolute import to import linalg. If we +# instead do +# +# from . import linalg +# +# It doesn't overwrite np.linalg from above. The import is generated +# dynamically so that the library can be vendored. +__import__(__package__ + '.linalg') + +from .linalg import matrix_transpose, vecdot + +from ._helpers import * From 076848ed261b6dcd89cec5809a523dc35def7500 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 17:49:24 -0700 Subject: [PATCH 07/24] Add a cupy submodule Full support still needs to be tested, and also we need to double check if any of the aliases are unnecessary for cupy. --- numpy_array_api_compat/common/_aliases.py | 30 +++-- numpy_array_api_compat/common/_helpers.py | 133 ++++++++++++++++++++++ numpy_array_api_compat/cupy/__init__.py | 4 +- numpy_array_api_compat/cupy/_aliases.py | 11 ++ numpy_array_api_compat/cupy/linalg.py | 14 +++ numpy_array_api_compat/numpy/__init__.py | 2 +- numpy_array_api_compat/numpy/_helpers.py | 79 ------------- 7 files changed, 174 insertions(+), 99 deletions(-) create mode 100644 numpy_array_api_compat/cupy/_aliases.py create mode 100644 numpy_array_api_compat/cupy/linalg.py delete mode 100644 numpy_array_api_compat/numpy/_helpers.py diff --git a/numpy_array_api_compat/common/_aliases.py b/numpy_array_api_compat/common/_aliases.py index 50cb62cf..198d2d3b 100644 --- a/numpy_array_api_compat/common/_aliases.py +++ b/numpy_array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple from types import ModuleType -from ..numpy._helpers import _is_numpy_array, get_namespace +from ._helpers import _check_device, _is_numpy_array, get_namespace from .._internal import get_xp # Basic renames @@ -189,10 +189,6 @@ def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: # Creation functions add the device keyword (which does nothing for NumPy) -def _check_device(device): - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device {device!r}") - # asarray also adds the copy keyword def _asarray( obj: Union[ @@ -232,7 +228,7 @@ def _asarray( else: raise ValueError("Unrecognized namespace argument to asarray()") - _check_device(device) + _check_device(xp, device) if _is_numpy_array(obj): import numpy as np COPY_FALSE = (False, np._CopyMode.IF_NEEDED) @@ -263,7 +259,7 @@ def arange( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype) @get_xp @@ -274,14 +270,14 @@ def empty( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.empty(shape, dtype=dtype) @get_xp def empty_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.empty_like(x, dtype=dtype) @get_xp @@ -295,7 +291,7 @@ def eye( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) @get_xp @@ -307,7 +303,7 @@ def full( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype) @get_xp @@ -320,7 +316,7 @@ def full_like( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype) @get_xp @@ -335,7 +331,7 @@ def linspace( device: Optional[Device] = None, endpoint: bool = True, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) @get_xp @@ -346,14 +342,14 @@ def ones( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.ones(shape, dtype=dtype) @get_xp def ones_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.ones_like(x, dtype=dtype) @get_xp @@ -364,14 +360,14 @@ def zeros( dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.zeros(shape, dtype=dtype) @get_xp def zeros_like( x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None ) -> ndarray: - _check_device(device) + _check_device(xp, device) return xp.zeros_like(x, dtype=dtype) # xp.reshape calls the keyword argument 'newshape' instead of 'shape' diff --git a/numpy_array_api_compat/common/_helpers.py b/numpy_array_api_compat/common/_helpers.py index 5090fd5d..3863c0cb 100644 --- a/numpy_array_api_compat/common/_helpers.py +++ b/numpy_array_api_compat/common/_helpers.py @@ -1,6 +1,40 @@ """ Various helper functions which are not part of the spec. + +Functions which start with an underscore are for internal use only but helpers +that are in __all__ are intended as additional helper functions for use by end +users of the compat library. """ +from __future__ import annotations + +import sys + +def _is_numpy_array(x): + # Avoid importing NumPy if it isn't already + if 'numpy' not in sys.modules: + return False + + import numpy as np + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, (np.ndarray, np.generic)) + +def _is_cupy_array(x): + # Avoid importing NumPy if it isn't already + if 'cupy' not in sys.modules: + return False + + import cupy as cp + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, (cp.ndarray, cp.generic)) + +def is_array_api_obj(x): + """ + Check if x is an array API compatible array object. + """ + return _is_numpy_array(x) or _is_cupy_array(x) or hasattr(x, '__array_namespace__') + def get_namespace(*xs, _use_compat=True): """ Get the array API compatible namespace for the arrays `xs`. @@ -35,3 +69,102 @@ def get_namespace(*xs, _use_compat=True): xp, = namespaces return xp + + +def _check_device(xp, device): + if xp == sys.modules.get('numpy'): + if device not in ["cpu", None]: + raise ValueError(f"Unsupported device for NumPy: {device!r}") + +# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray +# or cupy.ndarray. They are not included in array objects of this library +# because this library just reuses the respective ndarray classes without +# wrapping or subclassing them. These helper functions can be used instead of +# the wrapper functions for libraries that need to support both NumPy/CuPy and +# other libraries that use devices. +def device(x: "Array", /) -> "Device": + """ + Hardware device the array data resides on. + + Parameters + ---------- + x: array + array instance from NumPy or an array API compatible library. + + Returns + ------- + out: device + a ``device`` object (see the "Device Support" section of the array API specification). + """ + if _is_numpy_array(x): + return "cpu" + return x.device + +# Based on cupy.array_api.Array.to_device +def _cupy_to_device(x, device, /, stream=None): + import cupy as cp + from cupy.cuda import Device as _Device + from cupy.cuda import stream as stream_module + from cupy_backends.cuda.api import runtime + + if device == x.device: + return x + elif not isinstance(device, _Device): + raise ValueError(f"Unsupported device {device!r}") + else: + # see cupy/cupy#5985 for the reason how we handle device/stream here + prev_device = runtime.getDevice() + prev_stream: stream_module.Stream = None + if stream is not None: + prev_stream = stream_module.get_current_stream() + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif isinstance(stream, cp.cuda.Stream): + pass + else: + raise ValueError('the input stream is not recognized') + stream.use() + try: + runtime.setDevice(device.id) + arr = x.copy() + finally: + runtime.setDevice(prev_device) + if stream is not None: + prev_stream.use() + return arr + +def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array": + """ + Copy the array from the device on which it currently resides to the specified ``device``. + + Parameters + ---------- + x: array + array instance from NumPy or an array API compatible library. + device: device + a ``device`` object (see the "Device Support" section of the array API specification). + stream: Optional[Union[int, Any]] + stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. + + Returns + ------- + out: array + an array with the same data and data type as ``x`` and located on the specified ``device``. + + .. note:: + If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. + """ + if _is_numpy_array(x): + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") + if device == 'cpu': + return x + raise ValueError(f"Unsupported device {device!r}") + elif _is_cupy_array(x): + # cupy does not yet have to_device + return _cupy_to_device(x, device, stream=stream) + + return x.to_device(device, stream=stream) + +__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device'] diff --git a/numpy_array_api_compat/cupy/__init__.py b/numpy_array_api_compat/cupy/__init__.py index a3897fb7..bc760b2b 100644 --- a/numpy_array_api_compat/cupy/__init__.py +++ b/numpy_array_api_compat/cupy/__init__.py @@ -11,10 +11,10 @@ # # from . import linalg # -# It doesn't overwrite np.linalg from above. The import is generated +# It doesn't overwrite cupy.linalg from above. The import is generated # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') from .linalg import matrix_transpose, vecdot -from ._helpers import * +from ..common._helpers import * diff --git a/numpy_array_api_compat/cupy/_aliases.py b/numpy_array_api_compat/cupy/_aliases.py new file mode 100644 index 00000000..49e662e3 --- /dev/null +++ b/numpy_array_api_compat/cupy/_aliases.py @@ -0,0 +1,11 @@ +from functools import partial + +from ..common._aliases import * +from ..common._aliases import _asarray +from ..common._aliases import __all__ + +asarray = asarray_cupy = partial(_asarray, namespace='cupy') +asarray.__doc__ = _asarray.__doc__ +del partial + +__all__ = __all__ + ['asarray', 'asarray_cupy'] diff --git a/numpy_array_api_compat/cupy/linalg.py b/numpy_array_api_compat/cupy/linalg.py new file mode 100644 index 00000000..83c16b12 --- /dev/null +++ b/numpy_array_api_compat/cupy/linalg.py @@ -0,0 +1,14 @@ +from cupy.linalg import * +# cupy.linalg doesn't have __all__. If it is added, replace this with +# +# from cupy.linalg import __all__ as linalg_all +_n = {} +exec('from cupy.linalg import *', _n) +del _n['__builtins__'] +linalg_all = list(_n) +del _n + +from ..common.linalg import * +from ..common.linalg import __all__ as common_linalg_all + +__all__ = linalg_all + common_linalg_all diff --git a/numpy_array_api_compat/numpy/__init__.py b/numpy_array_api_compat/numpy/__init__.py index 46c000f9..32a1ea2b 100644 --- a/numpy_array_api_compat/numpy/__init__.py +++ b/numpy_array_api_compat/numpy/__init__.py @@ -17,4 +17,4 @@ from .linalg import matrix_transpose, vecdot -from ._helpers import * +from ..common._helpers import * diff --git a/numpy_array_api_compat/numpy/_helpers.py b/numpy_array_api_compat/numpy/_helpers.py deleted file mode 100644 index 037e0f7a..00000000 --- a/numpy_array_api_compat/numpy/_helpers.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Various helper functions which are not part of the spec. -""" - -from __future__ import annotations - -import sys - -from ..common._helpers import get_namespace - -def _is_numpy_array(x): - # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: - return False - - import numpy as np - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, (np.ndarray, np.generic)) - -def is_array_api_obj(x): - """ - Check if x is an array API compatible array object. - """ - return _is_numpy_array(x) or hasattr(x, '__array_namespace__') - -# device and to_device are not included in array object of this library -# because this library just reuses ndarray without wrapping or subclassing it. -# These helper functions can be used instead of the wrapper functions for -# libraries that need to support both NumPy and other libraries that use devices. -def device(x: "Array", /) -> "Device": - """ - Hardware device the array data resides on. - - Parameters - ---------- - x: array - array instance from NumPy or an array API compatible library. - - Returns - ------- - out: device - a ``device`` object (see the "Device Support" section of the array API specification). - """ - if _is_numpy_array(x): - return "cpu" - return x.device - -def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array": - """ - Copy the array from the device on which it currently resides to the specified ``device``. - - Parameters - ---------- - x: array - array instance from NumPy or an array API compatible library. - device: device - a ``device`` object (see the "Device Support" section of the array API specification). - stream: Optional[Union[int, Any]] - stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. - - Returns - ------- - out: array - an array with the same data and data type as ``x`` and located on the specified ``device``. - - .. note:: - If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. - """ - if _is_numpy_array(x): - if stream is not None: - raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': - return x - raise ValueError(f"Unsupported device {device!r}") - - return x.to_device(device, stream=stream) - -__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device'] From ed5705fbc86df007e562c6f62cd43d116d43b30a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 17:52:23 -0700 Subject: [PATCH 08/24] Rename numpy_array_api_compat/ to array_api_compat/ --- README.md | 18 ++++++++++++++++-- .../__init__.py | 0 .../_internal.py | 0 .../common/__init__.py | 0 .../common/_aliases.py | 0 .../common/_helpers.py | 0 .../common/_typing.py | 0 .../common/linalg.py | 0 .../cupy/__init__.py | 0 .../cupy/_aliases.py | 0 .../cupy/linalg.py | 0 .../numpy/__init__.py | 0 .../numpy/_aliases.py | 0 .../numpy/linalg.py | 0 14 files changed, 16 insertions(+), 2 deletions(-) rename {numpy_array_api_compat => array_api_compat}/__init__.py (100%) rename {numpy_array_api_compat => array_api_compat}/_internal.py (100%) rename {numpy_array_api_compat => array_api_compat}/common/__init__.py (100%) rename {numpy_array_api_compat => array_api_compat}/common/_aliases.py (100%) rename {numpy_array_api_compat => array_api_compat}/common/_helpers.py (100%) rename {numpy_array_api_compat => array_api_compat}/common/_typing.py (100%) rename {numpy_array_api_compat => array_api_compat}/common/linalg.py (100%) rename {numpy_array_api_compat => array_api_compat}/cupy/__init__.py (100%) rename {numpy_array_api_compat => array_api_compat}/cupy/_aliases.py (100%) rename {numpy_array_api_compat => array_api_compat}/cupy/linalg.py (100%) rename {numpy_array_api_compat => array_api_compat}/numpy/__init__.py (100%) rename {numpy_array_api_compat => array_api_compat}/numpy/_aliases.py (100%) rename {numpy_array_api_compat => array_api_compat}/numpy/linalg.py (100%) diff --git a/README.md b/README.md index 9d888663..9ccf4afb 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# NumPy Array API compatibility library +# Array API compatibility library This is a small wrapper around NumPy that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). See also [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html). @@ -13,6 +13,8 @@ separate Array object, but rather just uses `numpy.ndarray` directly. Note that some of the functionality in this library is backwards incompatible with NumPy. +This library also supports CuPy in addition to NumPy. + Library authors using the Array API may wish to test against `numpy.array_api` to ensure they are not using functionality outside of the standard, but prefer this implementation for end users who use NumPy arrays. @@ -28,5 +30,17 @@ import numpy as np with ```py -import numpy_array_api_compat as np +import array_api_compat.numpy as np +``` + +and replace + +```py +import cupy as cp +``` + +with + +```py +import array_api_compat.cupy as cp ``` diff --git a/numpy_array_api_compat/__init__.py b/array_api_compat/__init__.py similarity index 100% rename from numpy_array_api_compat/__init__.py rename to array_api_compat/__init__.py diff --git a/numpy_array_api_compat/_internal.py b/array_api_compat/_internal.py similarity index 100% rename from numpy_array_api_compat/_internal.py rename to array_api_compat/_internal.py diff --git a/numpy_array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py similarity index 100% rename from numpy_array_api_compat/common/__init__.py rename to array_api_compat/common/__init__.py diff --git a/numpy_array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py similarity index 100% rename from numpy_array_api_compat/common/_aliases.py rename to array_api_compat/common/_aliases.py diff --git a/numpy_array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py similarity index 100% rename from numpy_array_api_compat/common/_helpers.py rename to array_api_compat/common/_helpers.py diff --git a/numpy_array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py similarity index 100% rename from numpy_array_api_compat/common/_typing.py rename to array_api_compat/common/_typing.py diff --git a/numpy_array_api_compat/common/linalg.py b/array_api_compat/common/linalg.py similarity index 100% rename from numpy_array_api_compat/common/linalg.py rename to array_api_compat/common/linalg.py diff --git a/numpy_array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py similarity index 100% rename from numpy_array_api_compat/cupy/__init__.py rename to array_api_compat/cupy/__init__.py diff --git a/numpy_array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py similarity index 100% rename from numpy_array_api_compat/cupy/_aliases.py rename to array_api_compat/cupy/_aliases.py diff --git a/numpy_array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py similarity index 100% rename from numpy_array_api_compat/cupy/linalg.py rename to array_api_compat/cupy/linalg.py diff --git a/numpy_array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py similarity index 100% rename from numpy_array_api_compat/numpy/__init__.py rename to array_api_compat/numpy/__init__.py diff --git a/numpy_array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py similarity index 100% rename from numpy_array_api_compat/numpy/_aliases.py rename to array_api_compat/numpy/_aliases.py diff --git a/numpy_array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py similarity index 100% rename from numpy_array_api_compat/numpy/linalg.py rename to array_api_compat/numpy/linalg.py From 420c0daf678f1e8607bda1c778f00407847519ba Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 17:57:22 -0700 Subject: [PATCH 09/24] Add __array_api_version__ --- array_api_compat/cupy/__init__.py | 2 ++ array_api_compat/numpy/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index bc760b2b..e4f43c13 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -18,3 +18,5 @@ from .linalg import matrix_transpose, vecdot from ..common._helpers import * + +__array_api_version__ = '2021.12' diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 32a1ea2b..745367bc 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -18,3 +18,5 @@ from .linalg import matrix_transpose, vecdot from ..common._helpers import * + +__array_api_version__ = '2021.12' From 005852f719dcf3e6b1f6eb7f72b3bb7d1501da9c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:32:58 -0700 Subject: [PATCH 10/24] Remove library-specific stuff from common/_typing.py --- array_api_compat/common/_typing.py | 39 ------------------------- array_api_compat/cupy/_typing.py | 46 ++++++++++++++++++++++++++++++ array_api_compat/numpy/_typing.py | 46 ++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 39 deletions(-) create mode 100644 array_api_compat/cupy/_typing.py create mode 100644 array_api_compat/numpy/_typing.py diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index f49868e1..3f178060 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,59 +1,20 @@ from __future__ import annotations __all__ = [ - "ndarray", - "Device", - "Dtype", "NestedSequence", "SupportsBufferProtocol", ] -import sys from typing import ( Any, - Literal, - Union, - TYPE_CHECKING, TypeVar, Protocol, ) -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype - SupportsBufferProtocol = Any diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py new file mode 100644 index 00000000..f3d9aab6 --- /dev/null +++ b/array_api_compat/cupy/_typing.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +__all__ = [ + "ndarray", + "Device", + "Dtype", +] + +import sys +from typing import ( + Union, + TYPE_CHECKING, +) + +from cupy import ( + ndarray, + dtype, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) + +from cupy.cuda.device import Device + +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py new file mode 100644 index 00000000..c5ebb5ab --- /dev/null +++ b/array_api_compat/numpy/_typing.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +__all__ = [ + "ndarray", + "Device", + "Dtype", +] + +import sys +from typing import ( + Literal, + Union, + TYPE_CHECKING, +) + +from numpy import ( + ndarray, + dtype, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) + +Device = Literal["cpu"] +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype From 2912c9e3291423ce2713c99c01ccee65f175371f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:39:16 -0700 Subject: [PATCH 11/24] Refactor how get_xp works Now instead of guessing the array library from the input with get_namespace, the array library is hard-coded into the wrapped function based on which subnamespace it is imported from. --- array_api_compat/_internal.py | 34 ++-- array_api_compat/common/_aliases.py | 275 ++++++++++++---------------- array_api_compat/common/_helpers.py | 9 +- array_api_compat/common/linalg.py | 18 -- array_api_compat/cupy/_aliases.py | 62 ++++++- array_api_compat/cupy/linalg.py | 27 ++- array_api_compat/numpy/_aliases.py | 62 ++++++- array_api_compat/numpy/linalg.py | 27 ++- 8 files changed, 301 insertions(+), 213 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 0448cd53..85826a6a 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -5,36 +5,40 @@ from functools import wraps from inspect import signature -from .common._helpers import get_namespace - -def get_xp(f): +def get_xp(xp): """ - Decorator to automatically replace xp with the corresponding array module + Decorator to automatically replace xp with the corresponding array module. Use like - @get_xp + import numpy as np + + @get_xp(np) def func(x, /, xp, kwarg=None): return xp.func(x, kwarg=kwarg) - Note that xp must be able to be passed as a keyword argument. + Note that xp must be a keyword argument and come after all non-keyword + arguments. + """ - @wraps(f) - def inner(*args, **kwargs): - xp = get_namespace(*args, _use_compat=False) - return f(*args, xp=xp, **kwargs) + def inner(f): + sig = signature(f) + + @wraps(f) + def wrapped_f(*args, **kwargs): + return f(*args, xp=xp, **kwargs) - sig = signature(f) - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) - if inner.__doc__ is None: - inner.__doc__ = f"""\ + if wrapped_f.__doc__ is None: + wrapped_f.__doc__ = f"""\ Array API compatibility wrapper for {f.__name__}. See the corresponding documentation in NumPy/CuPy and/or the array API specification for more details. """ - inner.__signature__ = new_sig + # wrapped_f.__signature__ = new_sig + return wrapped_f return inner diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 198d2d3b..1551c671 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -13,63 +13,157 @@ from types import ModuleType from ._helpers import _check_device, _is_numpy_array, get_namespace -from .._internal import get_xp # Basic renames -@get_xp def acos(x, /, xp): return xp.arccos(x) -@get_xp def acosh(x, /, xp): return xp.arccosh(x) -@get_xp def asin(x, /, xp): return xp.arcsin(x) -@get_xp def asinh(x, /, xp): return xp.arcsinh(x) -@get_xp def atan(x, /, xp): return xp.arctan(x) -@get_xp def atan2(x1, x2, /, xp): return xp.arctan2(x1, x2) -@get_xp def atanh(x, /, xp): return xp.arctanh(x) -@get_xp def bitwise_left_shift(x1, x2, /, xp): return xp.left_shift(x1, x2) -@get_xp def bitwise_invert(x, /, xp): return xp.invert(x) -@get_xp def bitwise_right_shift(x1, x2, /, xp): return xp.right_shift(x1, x2) -@get_xp -def bool(x, /, xp): - return xp.bool_(x) - -@get_xp def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray: return xp.concatenate(arrays, axis=axis) -@get_xp def pow(x1, x2, /, xp): return xp.power(x1, x2) # These functions are modified from the NumPy versions. +def arange( + start: Union[int, float], + /, + xp, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.arange(start, stop=stop, step=step, dtype=dtype) + +def empty( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.empty(shape, dtype=dtype) + +def empty_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(xp, device) + return xp.empty_like(x, dtype=dtype) + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + xp, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.full(shape, fill_value, dtype=dtype) + +def full_like( + x: ndarray, + /, + xp, + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.full_like(x, fill_value, dtype=dtype) + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + xp, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> ndarray: + _check_device(xp, device) + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + +def ones( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.ones(shape, dtype=dtype) + +def ones_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(xp, device) + return xp.ones_like(x, dtype=dtype) + +def zeros( + shape: Union[int, Tuple[int, ...]], + xp, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> ndarray: + _check_device(xp, device) + return xp.zeros(shape, dtype=dtype) + +def zeros_like( + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> ndarray: + _check_device(xp, device) + return xp.zeros_like(x, dtype=dtype) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -93,7 +187,6 @@ class UniqueInverseResult(NamedTuple): inverse_indices: ndarray -@get_xp def unique_all(x: ndarray, /, xp) -> UniqueAllResult: values, indices, inverse_indices, counts = xp.unique( x, @@ -113,7 +206,6 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -@get_xp def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: res = xp.unique( x, @@ -126,7 +218,6 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -@get_xp def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: values, inverse_indices = xp.unique( x, @@ -141,7 +232,6 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -@get_xp def unique_values(x: ndarray, /, xp) -> ndarray: return xp.unique( x, @@ -158,7 +248,6 @@ def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: # These functions have different keyword argument names -@get_xp def std( x: ndarray, /, @@ -170,7 +259,6 @@ def std( ) -> ndarray: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims) -@get_xp def var( x: ndarray, /, @@ -183,7 +271,6 @@ def var( return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) # Unlike transpose(), the axes argument to permute_dims() is required. -@get_xp def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: return xp.transpose(x, axes) @@ -248,131 +335,8 @@ def _asarray( return xp.asarray(obj, dtype=dtype) -@get_xp -def arange( - start: Union[int, float], - /, - xp, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype) - -@get_xp -def empty( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.empty(shape, dtype=dtype) - -@get_xp -def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(xp, device) - return xp.empty_like(x, dtype=dtype) - -@get_xp -def eye( - n_rows: int, - n_cols: Optional[int] = None, - /, - *, - xp, - k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) - -@get_xp -def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype) - -@get_xp -def full_like( - x: ndarray, - /, - xp, - fill_value: Union[int, float], - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype) - -@get_xp -def linspace( - start: Union[int, float], - stop: Union[int, float], - /, - xp, - num: int, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - endpoint: bool = True, -) -> ndarray: - _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) - -@get_xp -def ones( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.ones(shape, dtype=dtype) - -@get_xp -def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(xp, device) - return xp.ones_like(x, dtype=dtype) - -@get_xp -def zeros( - shape: Union[int, Tuple[int, ...]], - xp, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, -) -> ndarray: - _check_device(xp, device) - return xp.zeros(shape, dtype=dtype) - -@get_xp -def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None -) -> ndarray: - _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype) - # xp.reshape calls the keyword argument 'newshape' instead of 'shape' -@get_xp -def reshape(x: ndarray, /, xp, shape: Tuple[int, ...], copy: Optional[bool] = None) -> ndarray: +def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None) -> ndarray: if copy is True: x = x.copy() elif copy is False: @@ -382,7 +346,6 @@ def reshape(x: ndarray, /, xp, shape: Tuple[int, ...], copy: Optional[bool] = No # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' -@get_xp def argsort( x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> ndarray: @@ -404,7 +367,6 @@ def argsort( res = max_i - res return res -@get_xp def sort( x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> ndarray: @@ -416,7 +378,6 @@ def sort( return res # sum() and prod() should always upcast when dtype=None -@get_xp def sum( x: ndarray, /, @@ -431,7 +392,6 @@ def sum( dtype = xp.float64 return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) -@get_xp def prod( x: ndarray, /, @@ -447,19 +407,16 @@ def prod( # ceil, floor, and trunc return integers for integer inputs -@get_xp def ceil(x: ndarray, /, xp) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x) -@get_xp def floor(x: ndarray, /, xp) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x) -@get_xp def trunc(x: ndarray, /, xp) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x @@ -467,10 +424,8 @@ def trunc(x: ndarray, /, xp) -> ndarray: __all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', + 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', - 'permute_dims', 'arange', 'empty', 'empty_like', 'eye', 'full', - 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', - 'zeros_like', 'reshape', 'argsort', 'sort', 'sum', 'prod', 'ceil', - 'floor', 'trunc'] + 'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod', + 'ceil', 'floor', 'trunc'] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 3863c0cb..a1310b1c 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -41,8 +41,6 @@ def get_namespace(*xs, _use_compat=True): `xs` should contain one or more arrays. """ - from ..numpy._helpers import _is_numpy_array - namespaces = set() for x in xs: if isinstance(x, (tuple, list)): @@ -56,6 +54,13 @@ def get_namespace(*xs, _use_compat=True): else: import numpy as np namespaces.add(np) + elif _is_cupy_array(x): + if _use_compat: + from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) + else: + import cupy as cp + namespaces.add(cp) else: # TODO: Support Python scalars? raise ValueError("The input is not a supported array type") diff --git a/array_api_compat/common/linalg.py b/array_api_compat/common/linalg.py index e2a2c478..df097ce5 100644 --- a/array_api_compat/common/linalg.py +++ b/array_api_compat/common/linalg.py @@ -10,19 +10,15 @@ from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg -@get_xp def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: return xp.cross(x1, x2, axis=axis) -@get_xp def matmul(x1: ndarray, x2: ndarray, /, xp) -> ndarray: return xp.matmul(x1, x2) -@get_xp def outer(x1: ndarray, x2: ndarray, /, xp) -> ndarray: return xp.outer(x1, x2) -@get_xp def tensordot(x1: ndarray, x2: ndarray, /, xp, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> ndarray: return xp.tensordot(x1, x2, axes=axes) @@ -45,26 +41,21 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -@get_xp def eigh(x: ndarray, /, xp) -> EighResult: return EighResult(*xp.linalg.eigh(x)) -@get_xp def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode)) -@get_xp def slogdet(x: ndarray, /, xp) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x)) -@get_xp def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -@get_xp def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray: L = xp.linalg.cholesky(x) if upper: @@ -73,7 +64,6 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -@get_xp def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. @@ -88,7 +78,6 @@ def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = No tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -@get_xp def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). @@ -98,12 +87,10 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> # These functions are new in the array API spec -@get_xp def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # Unlike transpose, matrix_transpose only transposes the last two axes. -@get_xp def matrix_transpose(x: ndarray, /, xp) -> ndarray: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") @@ -111,11 +98,9 @@ def matrix_transpose(x: ndarray, /, xp) -> ndarray: # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -@get_xp def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: return xp.linalg.svd(x, compute_uv=False) -@get_xp def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) @@ -130,7 +115,6 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] -@get_xp def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make @@ -168,11 +152,9 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -@get_xp def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1) -@get_xp def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1)) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 49e662e3..e92f9f8c 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,11 +1,61 @@ +from __future__ import annotations + from functools import partial -from ..common._aliases import * -from ..common._aliases import _asarray -from ..common._aliases import __all__ +from ..common import _aliases + +from .._internal import get_xp -asarray = asarray_cupy = partial(_asarray, namespace='cupy') -asarray.__doc__ = _asarray.__doc__ +asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray.__doc__ = _aliases._asarray.__doc__ del partial -__all__ = __all__ + ['asarray', 'asarray_cupy'] +import cupy as cp +bool = cp.bool_ + +acos = get_xp(cp)(_aliases.acos) +acosh = get_xp(cp)(_aliases.acosh) +asin = get_xp(cp)(_aliases.asin) +asinh = get_xp(cp)(_aliases.asinh) +atan = get_xp(cp)(_aliases.atan) +atan2 = get_xp(cp)(_aliases.atan2) +atanh = get_xp(cp)(_aliases.atanh) +bitwise_left_shift = get_xp(cp)(_aliases.bitwise_left_shift) +bitwise_invert = get_xp(cp)(_aliases.bitwise_invert) +bitwise_right_shift = get_xp(cp)(_aliases.bitwise_right_shift) +concat = get_xp(cp)(_aliases.concat) +pow = get_xp(cp)(_aliases.pow) +arange = get_xp(cp)(_aliases.arange) +empty = get_xp(cp)(_aliases.empty) +empty_like = get_xp(cp)(_aliases.empty_like) +eye = get_xp(cp)(_aliases.eye) +full = get_xp(cp)(_aliases.full) +full_like = get_xp(cp)(_aliases.full_like) +linspace = get_xp(cp)(_aliases.linspace) +ones = get_xp(cp)(_aliases.ones) +ones_like = get_xp(cp)(_aliases.ones_like) +zeros = get_xp(cp)(_aliases.zeros) +zeros_like = get_xp(cp)(_aliases.zeros_like) +UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult) +unique_all = get_xp(cp)(_aliases.unique_all) +unique_counts = get_xp(cp)(_aliases.unique_counts) +unique_inverse = get_xp(cp)(_aliases.unique_inverse) +unique_values = get_xp(cp)(_aliases.unique_values) +astype = _aliases.astype +std = get_xp(cp)(_aliases.std) +var = get_xp(cp)(_aliases.var) +permute_dims = get_xp(cp)(_aliases.permute_dims) +reshape = get_xp(cp)(_aliases.reshape) +argsort = get_xp(cp)(_aliases.argsort) +sort = get_xp(cp)(_aliases.sort) +sum = get_xp(cp)(_aliases.sum) +prod = get_xp(cp)(_aliases.prod) +ceil = get_xp(cp)(_aliases.ceil) +floor = get_xp(cp)(_aliases.floor) +trunc = get_xp(cp)(_aliases.trunc) + +__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'arange', + 'empty', 'empty_like', 'eye', 'full', 'full_like', + 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 83c16b12..5d32fbdb 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -8,7 +8,28 @@ linalg_all = list(_n) del _n -from ..common.linalg import * -from ..common.linalg import __all__ as common_linalg_all +from ..common import linalg +from .._internal import get_xp -__all__ = linalg_all + common_linalg_all +import cupy as cp + +cross = get_xp(cp)(linalg.cross) +diagonal = get_xp(cp)(linalg.diagonal) +matmul = get_xp(cp)(linalg.matmul) +cholesky = get_xp(cp)(linalg.cholesky) +matrix_rank = get_xp(cp)(linalg.matrix_rank) +pinv = get_xp(cp)(linalg.pinv) +matrix_norm = get_xp(cp)(linalg.matrix_norm) +matrix_transpose = get_xp(cp)(linalg.matrix_transpose) +outer = get_xp(cp)(linalg.outer) +svdvals = get_xp(cp)(linalg.svdvals) +tensordot = get_xp(cp)(linalg.tensordot) +trace = get_xp(cp)(linalg.trace) +vecdot = get_xp(cp)(linalg.vecdot) +vector_norm = get_xp(cp)(linalg.vector_norm) + +__all__ = linalg_all + linalg.__all__ + +del get_xp +del cp +del linalg_all diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index ae0172c7..e6ff2ee2 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,11 +1,61 @@ +from __future__ import annotations + from functools import partial -from ..common._aliases import * -from ..common._aliases import _asarray -from ..common._aliases import __all__ +from ..common import _aliases + +from .._internal import get_xp -asarray = asarray_numpy = partial(_asarray, namespace='numpy') -asarray.__doc__ = _asarray.__doc__ +asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray.__doc__ = _aliases._asarray.__doc__ del partial -__all__ = __all__ + ['asarray', 'asarray_numpy'] +import numpy as np +bool = np.bool_ + +acos = get_xp(np)(_aliases.acos) +acosh = get_xp(np)(_aliases.acosh) +asin = get_xp(np)(_aliases.asin) +asinh = get_xp(np)(_aliases.asinh) +atan = get_xp(np)(_aliases.atan) +atan2 = get_xp(np)(_aliases.atan2) +atanh = get_xp(np)(_aliases.atanh) +bitwise_left_shift = get_xp(np)(_aliases.bitwise_left_shift) +bitwise_invert = get_xp(np)(_aliases.bitwise_invert) +bitwise_right_shift = get_xp(np)(_aliases.bitwise_right_shift) +concat = get_xp(np)(_aliases.concat) +pow = get_xp(np)(_aliases.pow) +arange = get_xp(np)(_aliases.arange) +empty = get_xp(np)(_aliases.empty) +empty_like = get_xp(np)(_aliases.empty_like) +eye = get_xp(np)(_aliases.eye) +full = get_xp(np)(_aliases.full) +full_like = get_xp(np)(_aliases.full_like) +linspace = get_xp(np)(_aliases.linspace) +ones = get_xp(np)(_aliases.ones) +ones_like = get_xp(np)(_aliases.ones_like) +zeros = get_xp(np)(_aliases.zeros) +zeros_like = get_xp(np)(_aliases.zeros_like) +UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult) +unique_all = get_xp(np)(_aliases.unique_all) +unique_counts = get_xp(np)(_aliases.unique_counts) +unique_inverse = get_xp(np)(_aliases.unique_inverse) +unique_values = get_xp(np)(_aliases.unique_values) +astype = _aliases.astype +std = get_xp(np)(_aliases.std) +var = get_xp(np)(_aliases.var) +permute_dims = get_xp(np)(_aliases.permute_dims) +reshape = get_xp(np)(_aliases.reshape) +argsort = get_xp(np)(_aliases.argsort) +sort = get_xp(np)(_aliases.sort) +sum = get_xp(np)(_aliases.sum) +prod = get_xp(np)(_aliases.prod) +ceil = get_xp(np)(_aliases.ceil) +floor = get_xp(np)(_aliases.floor) +trunc = get_xp(np)(_aliases.trunc) + +__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'arange', + 'empty', 'empty_like', 'eye', 'full', 'full_like', + 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 1f6c95c0..b138de8e 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,7 +1,28 @@ from numpy.linalg import * from numpy.linalg import __all__ as linalg_all -from ..common.linalg import * -from ..common.linalg import __all__ as common_linalg_all +from ..common import linalg +from .._internal import get_xp -__all__ = linalg_all + common_linalg_all +import numpy as np + +cross = get_xp(np)(linalg.cross) +diagonal = get_xp(np)(linalg.diagonal) +matmul = get_xp(np)(linalg.matmul) +cholesky = get_xp(np)(linalg.cholesky) +matrix_rank = get_xp(np)(linalg.matrix_rank) +pinv = get_xp(np)(linalg.pinv) +matrix_norm = get_xp(np)(linalg.matrix_norm) +matrix_transpose = get_xp(np)(linalg.matrix_transpose) +outer = get_xp(np)(linalg.outer) +svdvals = get_xp(np)(linalg.svdvals) +tensordot = get_xp(np)(linalg.tensordot) +trace = get_xp(np)(linalg.trace) +vecdot = get_xp(np)(linalg.vecdot) +vector_norm = get_xp(np)(linalg.vector_norm) + +__all__ = linalg_all + linalg.__all__ + +del get_xp +del np +del linalg_all From 5775b11dd011dd835569699a812580d842a46d47 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:42:48 -0700 Subject: [PATCH 12/24] Rename common.linalg to common._linalg --- .../common/{linalg.py => _linalg.py} | 0 array_api_compat/cupy/linalg.py | 33 ++++++++++--------- array_api_compat/numpy/linalg.py | 33 ++++++++++--------- 3 files changed, 34 insertions(+), 32 deletions(-) rename array_api_compat/common/{linalg.py => _linalg.py} (100%) diff --git a/array_api_compat/common/linalg.py b/array_api_compat/common/_linalg.py similarity index 100% rename from array_api_compat/common/linalg.py rename to array_api_compat/common/_linalg.py diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 5d32fbdb..b11e14b6 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -8,28 +8,29 @@ linalg_all = list(_n) del _n -from ..common import linalg +from ..common import _linalg from .._internal import get_xp import cupy as cp -cross = get_xp(cp)(linalg.cross) -diagonal = get_xp(cp)(linalg.diagonal) -matmul = get_xp(cp)(linalg.matmul) -cholesky = get_xp(cp)(linalg.cholesky) -matrix_rank = get_xp(cp)(linalg.matrix_rank) -pinv = get_xp(cp)(linalg.pinv) -matrix_norm = get_xp(cp)(linalg.matrix_norm) -matrix_transpose = get_xp(cp)(linalg.matrix_transpose) -outer = get_xp(cp)(linalg.outer) -svdvals = get_xp(cp)(linalg.svdvals) -tensordot = get_xp(cp)(linalg.tensordot) -trace = get_xp(cp)(linalg.trace) -vecdot = get_xp(cp)(linalg.vecdot) -vector_norm = get_xp(cp)(linalg.vector_norm) +cross = get_xp(cp)(_linalg.cross) +diagonal = get_xp(cp)(_linalg.diagonal) +matmul = get_xp(cp)(_linalg.matmul) +cholesky = get_xp(cp)(_linalg.cholesky) +matrix_rank = get_xp(cp)(_linalg.matrix_rank) +pinv = get_xp(cp)(_linalg.pinv) +matrix_norm = get_xp(cp)(_linalg.matrix_norm) +matrix_transpose = get_xp(cp)(_linalg.matrix_transpose) +outer = get_xp(cp)(_linalg.outer) +svdvals = get_xp(cp)(_linalg.svdvals) +tensordot = get_xp(cp)(_linalg.tensordot) +trace = get_xp(cp)(_linalg.trace) +vecdot = get_xp(cp)(_linalg.vecdot) +vector_norm = get_xp(cp)(_linalg.vector_norm) -__all__ = linalg_all + linalg.__all__ +__all__ = linalg_all + _linalg.__all__ del get_xp del cp del linalg_all +del _linalg diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index b138de8e..4b373267 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,28 +1,29 @@ from numpy.linalg import * from numpy.linalg import __all__ as linalg_all -from ..common import linalg +from ..common import _linalg from .._internal import get_xp import numpy as np -cross = get_xp(np)(linalg.cross) -diagonal = get_xp(np)(linalg.diagonal) -matmul = get_xp(np)(linalg.matmul) -cholesky = get_xp(np)(linalg.cholesky) -matrix_rank = get_xp(np)(linalg.matrix_rank) -pinv = get_xp(np)(linalg.pinv) -matrix_norm = get_xp(np)(linalg.matrix_norm) -matrix_transpose = get_xp(np)(linalg.matrix_transpose) -outer = get_xp(np)(linalg.outer) -svdvals = get_xp(np)(linalg.svdvals) -tensordot = get_xp(np)(linalg.tensordot) -trace = get_xp(np)(linalg.trace) -vecdot = get_xp(np)(linalg.vecdot) -vector_norm = get_xp(np)(linalg.vector_norm) +cross = get_xp(np)(_linalg.cross) +diagonal = get_xp(np)(_linalg.diagonal) +matmul = get_xp(np)(_linalg.matmul) +cholesky = get_xp(np)(_linalg.cholesky) +matrix_rank = get_xp(np)(_linalg.matrix_rank) +pinv = get_xp(np)(_linalg.pinv) +matrix_norm = get_xp(np)(_linalg.matrix_norm) +matrix_transpose = get_xp(np)(_linalg.matrix_transpose) +outer = get_xp(np)(_linalg.outer) +svdvals = get_xp(np)(_linalg.svdvals) +tensordot = get_xp(np)(_linalg.tensordot) +trace = get_xp(np)(_linalg.trace) +vecdot = get_xp(np)(_linalg.vecdot) +vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + linalg.__all__ +__all__ = linalg_all + _linalg.__all__ del get_xp del np del linalg_all +del _linalg From 936a8ad8d89b501fab2d38044c9bfcd8ef153e08 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:43:37 -0700 Subject: [PATCH 13/24] Fix arange() --- array_api_compat/common/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 1551c671..da0a1a5c 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -56,10 +56,10 @@ def pow(x1, x2, /, xp): def arange( start: Union[int, float], /, - xp, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, + xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: From d88f709f061459fa748db5af8728dff7da075720 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:44:42 -0700 Subject: [PATCH 14/24] Fix cupy asarray to create cupy arrays instead of numpy arrays --- array_api_compat/cupy/_aliases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index e92f9f8c..cbc89381 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -6,7 +6,7 @@ from .._internal import get_xp -asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') asarray.__doc__ = _aliases._asarray.__doc__ del partial @@ -56,6 +56,6 @@ floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) -__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'arange', +__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] From 360ea18798f9aa7d628ce0e64044120ab152def6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:50:34 -0700 Subject: [PATCH 15/24] Re-enable the signature fix in get_xp --- array_api_compat/_internal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 85826a6a..553c0356 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -22,12 +22,11 @@ def func(x, /, xp, kwarg=None): """ def inner(f): - sig = signature(f) - @wraps(f) def wrapped_f(*args, **kwargs): return f(*args, xp=xp, **kwargs) + sig = signature(f) new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) if wrapped_f.__doc__ is None: @@ -38,7 +37,7 @@ def wrapped_f(*args, **kwargs): specification for more details. """ - # wrapped_f.__signature__ = new_sig + wrapped_f.__signature__ = new_sig return wrapped_f return inner From fece5e0609e2d01e68f2793c97a8644123a7f2da Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Nov 2022 22:59:56 -0700 Subject: [PATCH 16/24] Fix some issues with the linalg wrapping --- array_api_compat/common/_linalg.py | 11 ++++++----- array_api_compat/cupy/linalg.py | 16 ++++++++++++---- array_api_compat/numpy/linalg.py | 16 ++++++++++++---- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index df097ce5..f3bec324 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -59,7 +59,7 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult: def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray: L = xp.linalg.cholesky(x) if upper: - return matrix_transpose(L) + return get_xp(xp)(matrix_transpose)(L) return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. @@ -158,7 +158,8 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1)) -__all__ = ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv', - 'matrix_norm', 'matrix_transpose', 'outer', 'svdvals', - 'tensordot', 'trace', 'vecdot', 'vector_norm', 'EighResult', - 'QRResult', 'SlogdetResult', 'SVDResult'] +__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', + 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', + 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', + 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', + 'trace'] diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index b11e14b6..04c71dec 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -14,19 +14,27 @@ import cupy as cp cross = get_xp(cp)(_linalg.cross) -diagonal = get_xp(cp)(_linalg.diagonal) matmul = get_xp(cp)(_linalg.matmul) +outer = get_xp(cp)(_linalg.outer) +tensordot = get_xp(cp)(_linalg.tensordot) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(cp)(_linalg.eigh) +qr = get_xp(cp)(_linalg.qr) +slogdet = get_xp(cp)(_linalg.slogdet) +svd = get_xp(cp)(_linalg.svd) cholesky = get_xp(cp)(_linalg.cholesky) matrix_rank = get_xp(cp)(_linalg.matrix_rank) pinv = get_xp(cp)(_linalg.pinv) matrix_norm = get_xp(cp)(_linalg.matrix_norm) matrix_transpose = get_xp(cp)(_linalg.matrix_transpose) -outer = get_xp(cp)(_linalg.outer) svdvals = get_xp(cp)(_linalg.svdvals) -tensordot = get_xp(cp)(_linalg.tensordot) -trace = get_xp(cp)(_linalg.trace) vecdot = get_xp(cp)(_linalg.vecdot) vector_norm = get_xp(cp)(_linalg.vector_norm) +diagonal = get_xp(cp)(_linalg.diagonal) +trace = get_xp(cp)(_linalg.trace) __all__ = linalg_all + _linalg.__all__ diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 4b373267..ac04b055 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,19 +7,27 @@ import numpy as np cross = get_xp(np)(_linalg.cross) -diagonal = get_xp(np)(_linalg.diagonal) matmul = get_xp(np)(_linalg.matmul) +outer = get_xp(np)(_linalg.outer) +tensordot = get_xp(np)(_linalg.tensordot) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(np)(_linalg.eigh) +qr = get_xp(np)(_linalg.qr) +slogdet = get_xp(np)(_linalg.slogdet) +svd = get_xp(np)(_linalg.svd) cholesky = get_xp(np)(_linalg.cholesky) matrix_rank = get_xp(np)(_linalg.matrix_rank) pinv = get_xp(np)(_linalg.pinv) matrix_norm = get_xp(np)(_linalg.matrix_norm) matrix_transpose = get_xp(np)(_linalg.matrix_transpose) -outer = get_xp(np)(_linalg.outer) svdvals = get_xp(np)(_linalg.svdvals) -tensordot = get_xp(np)(_linalg.tensordot) -trace = get_xp(np)(_linalg.trace) vecdot = get_xp(np)(_linalg.vecdot) vector_norm = get_xp(np)(_linalg.vector_norm) +diagonal = get_xp(np)(_linalg.diagonal) +trace = get_xp(np)(_linalg.trace) __all__ = linalg_all + _linalg.__all__ From c91360b9228cf2df69b20562047d966af062e403 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 14:25:24 -0700 Subject: [PATCH 17/24] Export helpers to the top-level namespace --- array_api_compat/__init__.py | 1 + array_api_compat/common/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index babe1d72..565d9113 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -40,3 +40,4 @@ - NumPy functions which are not wrapped may not use positional-only arguments. """ +from .common import * diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index e69de29b..ce3f44dd 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -0,0 +1 @@ +from ._helpers import * From d19c1a2a9cd936973de3b74fbc2b1dcdbbe36afb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 14:26:20 -0700 Subject: [PATCH 18/24] Fix full_like and linspace --- array_api_compat/common/_aliases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index da0a1a5c..bcbb2e2e 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -109,9 +109,9 @@ def full( def full_like( x: ndarray, /, - xp, fill_value: Union[int, float], *, + xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> ndarray: @@ -122,9 +122,9 @@ def linspace( start: Union[int, float], stop: Union[int, float], /, - xp, num: int, *, + xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True, From 6c54b6b34c5cb957556eb86d043d71ec2072694c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 14:32:18 -0700 Subject: [PATCH 19/24] Fix permute_dims --- array_api_compat/common/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index bcbb2e2e..ec8bd6e1 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -271,7 +271,7 @@ def var( return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray: +def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: return xp.transpose(x, axes) # Creation functions add the device keyword (which does nothing for NumPy) From e996d22a8f867b4b324c8a9ea63d811aa2208fcf Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 17:22:54 -0700 Subject: [PATCH 20/24] Add more information to the README --- README.md | 145 +++++++++++++++++++++++++++++++++-- array_api_compat/__init__.py | 27 +------ 2 files changed, 142 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 9ccf4afb..583eb8a4 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,27 @@ # Array API compatibility library -This is a small wrapper around NumPy that is compatible with the [Array API -standard](https://data-apis.org/array-api/latest/). See also [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html). +This is a small wrapper around NumPy and CuPy that is compatible with the +[Array API standard](https://data-apis.org/array-api/latest/). See also [NEP +47](https://numpy.org/neps/nep-0047-array-api-standard.html). Unlike `numpy.array_api`, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with -changes needed to be compliant with the Array API. See -https://numpy.org/doc/stable/reference/array_api.html for a full list of +changes needed to be compliant with the Array API. + +Library authors using the Array API may wish to test against numpy.array_api +to ensure they are not using functionality outside of the standard, but prefer +this implementation for the default when working with NumPy arrays. + +See https://numpy.org/doc/stable/reference/array_api.html for a full list of changes. In particular, unlike `numpy.array_api`, this package does not use a separate Array object, but rather just uses `numpy.ndarray` directly. Note that some of the functionality in this library is backwards incompatible with NumPy. -This library also supports CuPy in addition to NumPy. +This library also supports CuPy in addition to NumPy. If you want support for +other array libraries, please [open an +issue](https://github.com/data-apis/array-api-compat/issues). Library authors using the Array API may wish to test against `numpy.array_api` to ensure they are not using functionality outside of the standard, but prefer @@ -44,3 +52,130 @@ with ```py import array_api_compat.cupy as cp ``` + +Each will include all the functions from the normal NumPy/CuPy namespace, +except that functions that are part of the array API are wrapped so that they +have the correct array API behavior. In each case, the array object + + +## Helper Functions + +In addition to the default NumPy/CuPy namespace and functions in the array API +specification, there are several helper functions +included that aren't part of the specification but which are useful for using +the array API: + +- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array + object. + +- `get_namespace(*xs)`: Get the corresponding array API namespace for the + arrays `xs`. If the arrays are NumPy or CuPy arrays, the returned namespace + will be `array_api_compat.numpy` or `array_api_compat.cupy` so that it is + array API compatible. + +- `device(x)`: Equivalent to + [`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html) + in the array API specification. Included because `numpy.ndarray` does not + include the `device` attribute and this library does not wrap or extend the + array object. Note that for NumPy, `device` is always `"cpu"`. + +- `to_device(x, device, /, *, stream=None)`: Equivalent to + [`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html). + Included because neither NumPy's nor CuPy's ndarray objects include this + method. For NumPy, this function effectively does nothing since the only + supported device is the CPU, but for CuPy, this method supports CuPy CUDA + [Device](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Device.html) + and + [Stream](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html) + objects. + +## Known Differences from the Array API Specification + +There are some known differences between this library and the array API +specification: + +- The array methods `__array_namespace__`, `device` (for NumPy), `to_device`, + and `mT` are not defined. This reuses `np.ndarray` and `cp.ndarray` and we + don't want to monkeypatch or wrap it. The helper functions `device()` and + `to_device()` are provided to work around these missing methods (see above). + `x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`. + `get_namespace(x)` should be used instead of `x.__array_namespace__`. + +- NumPy value-based casting for scalars will be in effect unless explicitly + disabled with the environment variable NPY_PROMOTION_STATE=weak or + np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50 + and https://github.com/numpy/numpy/issues/22341) + +- Functions which are not wrapped may not have the same type annotations + as the spec. + +- Functions which are not wrapped may not use positional-only arguments. + +## Vendoring + +This library supports vendoring as an installation method. To vendor the +library, simply copy `array_api_compat` into the appropriate place in the +library, like + +``` +cp -R array_api_compat/ mylib/vendored/array_api_compat +``` + +You may also rename it to something else if you like (nowhere in the code +references the name "array_api_compat"). + +Alternatively, the library may be installed as dependency on PyPI. + +## Implementation + +As noted before, the goal of this library is to reuse the NumPy and CuPy array +objects, rather than wrapping or extending them. This means that the functions +need to accept and return `np.ndarray` for NumPy and `cp.ndarray` for CuPy. + +Each namespace (`array_api_compat.numpy` and `array_api_compat.cupy`) is +populated with the normal library namespace (like `from numpy import *`). Then +specific functions are replaced with wrapped variants. Wrapped functions that +have the same logic between NumPy and CuPy (which is most functions) are in +`array_api_compat/common/`. These functions are defined like + +```py +# In array_api_compat/common/_aliases.py + +def acos(x, /, xp): + return xp.arccos(x) +``` + +The `xp` argument refers to the original array namespace (either `numpy` or +`cupy`). Then in the specific `array_api_compat/numpy` and +`array_api_compat/cupy` namespace, the `get_xp` decorator is applied to these +functions, which automatically removes the `xp` argument from the function +signature and replaces it with the corresponding array library, like + +```py +# In array_api_compat/numpy/_aliases.py + +from ..common import _aliases + +import numpy as np + +acos = get_xp(np)(_aliases.acos) +``` + +This `acos` now has the signature `acos(x, /)` and calls `numpy.arccos`. + +Similarly, for CuPy: + +```py +# In array_api_compat/cupy/_aliases.py + +from ..common import _aliases + +import cupy as cp + +acos = get_xp(cp)(_aliases.acos) +``` + +Since NumPy and CuPy are nearly identical in their behaviors, this allows +writing the wrapping logic for both libraries only once. If support is added +for other libraries which differ significantly from NumPy, their wrapper code +should go in their specific sub-namespace instead of `common/`. diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 565d9113..34bd7565 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,8 +1,8 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy that is compatible with the Array API -standard https://data-apis.org/array-api/latest/. See also NEP 47 +This is a small wrapper around NumPy and CuPy that is compatible with the +Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike numpy.array_api, this is not a strict minimal implementation of the @@ -16,28 +16,5 @@ to ensure they are not using functionality outside of the standard, but prefer this implementation for the default when working with NumPy arrays. -In addition, several helper functions are provided in this library which are -not part of the array API specification but which are useful for libraries -writing against the array API specification who wish to support NumPy and -other array API compatible libraries. - -Known differences from the Array API spec: - -- The array methods __array_namespace__, device, to_device, and mT are not - defined. This reuses np.ndarray and we don't want to monkeypatch or wrap it. - The helper functions device() and to_device() are provided to work around - these missing methods. x.mT can be replaced with - xp.linalg.matrix_transpose(x). - -- NumPy value-based casting for scalars will be in effect unless explicitly - disabled with the environment variable NPY_PROMOTION_STATE=weak or - np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50 - and https://github.com/numpy/numpy/issues/22341) - -- NumPy functions which are not wrapped may not have the same type annotations - as the spec. - -- NumPy functions which are not wrapped may not use positional-only arguments. - """ from .common import * From 82365ebf22c85076094a9d66516c423c36f232a7 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 17:36:56 -0700 Subject: [PATCH 21/24] Add a test that vendoring works --- tests/__init__.py | 7 +++++++ tests/test_vendoring.py | 15 +++++++++++++++ tests/vendor_test/__init__.py | 0 tests/vendor_test/uses_cupy.py | 18 ++++++++++++++++++ tests/vendor_test/uses_numpy.py | 18 ++++++++++++++++++ tests/vendor_test/vendored/__init__.py | 0 tests/vendor_test/vendored/_compat | 1 + 7 files changed, 59 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_vendoring.py create mode 100644 tests/vendor_test/__init__.py create mode 100644 tests/vendor_test/uses_cupy.py create mode 100644 tests/vendor_test/uses_numpy.py create mode 100644 tests/vendor_test/vendored/__init__.py create mode 120000 tests/vendor_test/vendored/_compat diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..abc8f0f1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,7 @@ +""" +Basic tests for the compat library + +This only tests basic things like that vendoring works. The extensive tests +are done by the array API test suite https://github.com/data-apis/array-api-tests + +""" diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py new file mode 100644 index 00000000..85f68626 --- /dev/null +++ b/tests/test_vendoring.py @@ -0,0 +1,15 @@ +from pytest import skip + +def test_vendoring_numpy(): + from vendor_test import uses_numpy + uses_numpy._test_numpy() + + +def test_vendoring_cupy(): + try: + import cupy + except ImportError: + skip("CuPy is not installed") + + from vendor_test import uses_cupy + uses_cupy._test_cupy() diff --git a/tests/vendor_test/__init__.py b/tests/vendor_test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vendor_test/uses_cupy.py b/tests/vendor_test/uses_cupy.py new file mode 100644 index 00000000..97f710b9 --- /dev/null +++ b/tests/vendor_test/uses_cupy.py @@ -0,0 +1,18 @@ +# Basic test that vendoring works + +from .vendored._compat import cupy as cp_compat + +import cupy as cp + +def _test_cupy(): + a = cp_compat.asarray([1., 2., 3.]) + b = cp_compat.arange(3, dtype=cp_compat.float32) + + # cp.pow does not exist. Update this to use something else if it is added + res = cp_compat.pow(a, b) + assert res.dtype == cp_compat.float64 == cp.float64 + assert isinstance(a, cp.ndarray) + assert isinstance(b, cp.ndarray) + assert isinstance(res, cp.ndarray) + + cp.testing.assert_allclose(res, [1., 2., 9.]) diff --git a/tests/vendor_test/uses_numpy.py b/tests/vendor_test/uses_numpy.py new file mode 100644 index 00000000..96f2c5ff --- /dev/null +++ b/tests/vendor_test/uses_numpy.py @@ -0,0 +1,18 @@ +# Basic test that vendoring works + +from .vendored._compat import numpy as np_compat + +import numpy as np + +def _test_numpy(): + a = np_compat.asarray([1., 2., 3.]) + b = np_compat.arange(3, dtype=np_compat.float32) + + # np.pow does not exist. Update this to use something else if it is added + res = np_compat.pow(a, b) + assert res.dtype == np_compat.float64 == np.float64 + assert isinstance(a, np.ndarray) + assert isinstance(b, np.ndarray) + assert isinstance(res, np.ndarray) + + np.testing.assert_allclose(res, [1., 2., 9.]) diff --git a/tests/vendor_test/vendored/__init__.py b/tests/vendor_test/vendored/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vendor_test/vendored/_compat b/tests/vendor_test/vendored/_compat new file mode 120000 index 00000000..4843524d --- /dev/null +++ b/tests/vendor_test/vendored/_compat @@ -0,0 +1 @@ +../../../array_api_compat/ \ No newline at end of file From a83b15c49f028d4c4481b6f446a06b86a5152c1a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 30 Nov 2022 17:42:26 -0700 Subject: [PATCH 22/24] Fixes to the README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 583eb8a4..5f74b1b6 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@ This is a small wrapper around NumPy and CuPy that is compatible with the 47](https://numpy.org/neps/nep-0047-array-api-standard.html). Unlike `numpy.array_api`, this is not a strict minimal implementation of the -Array API, but rather just an extension of the main NumPy namespace with -changes needed to be compliant with the Array API. +Array API, but rather just an extension of the main NumPy and CuPy namespaces +with changes needed to be compliant with the Array API. -Library authors using the Array API may wish to test against numpy.array_api +Library authors using the Array API may wish to test against `numpy.array_api` to ensure they are not using functionality outside of the standard, but prefer -this implementation for the default when working with NumPy arrays. +this implementation for the default when working with NumPy or CuPy arrays. See https://numpy.org/doc/stable/reference/array_api.html for a full list of changes. In particular, unlike `numpy.array_api`, this package does not use a From 8d2d37ab6e00c5416474e9bf1699f5ec3f2254fb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Dec 2022 15:59:27 -0700 Subject: [PATCH 23/24] Fix missing sentence in the README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5f74b1b6..1f5b7c40 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,8 @@ import array_api_compat.cupy as cp Each will include all the functions from the normal NumPy/CuPy namespace, except that functions that are part of the array API are wrapped so that they -have the correct array API behavior. In each case, the array object +have the correct array API behavior. In each case, the array object used will +be thew same array object from the wrapped library. ## Helper Functions From 732b4933fb46c0384f341a8c687e225dd0adacd2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Dec 2022 16:01:19 -0700 Subject: [PATCH 24/24] Move vendor_test to the top-level --- tests/vendor_test/vendored/_compat | 1 - {tests/vendor_test => vendor_test}/__init__.py | 0 {tests/vendor_test => vendor_test}/uses_cupy.py | 0 {tests/vendor_test => vendor_test}/uses_numpy.py | 0 {tests/vendor_test => vendor_test}/vendored/__init__.py | 0 vendor_test/vendored/_compat | 1 + 6 files changed, 1 insertion(+), 1 deletion(-) delete mode 120000 tests/vendor_test/vendored/_compat rename {tests/vendor_test => vendor_test}/__init__.py (100%) rename {tests/vendor_test => vendor_test}/uses_cupy.py (100%) rename {tests/vendor_test => vendor_test}/uses_numpy.py (100%) rename {tests/vendor_test => vendor_test}/vendored/__init__.py (100%) create mode 120000 vendor_test/vendored/_compat diff --git a/tests/vendor_test/vendored/_compat b/tests/vendor_test/vendored/_compat deleted file mode 120000 index 4843524d..00000000 --- a/tests/vendor_test/vendored/_compat +++ /dev/null @@ -1 +0,0 @@ -../../../array_api_compat/ \ No newline at end of file diff --git a/tests/vendor_test/__init__.py b/vendor_test/__init__.py similarity index 100% rename from tests/vendor_test/__init__.py rename to vendor_test/__init__.py diff --git a/tests/vendor_test/uses_cupy.py b/vendor_test/uses_cupy.py similarity index 100% rename from tests/vendor_test/uses_cupy.py rename to vendor_test/uses_cupy.py diff --git a/tests/vendor_test/uses_numpy.py b/vendor_test/uses_numpy.py similarity index 100% rename from tests/vendor_test/uses_numpy.py rename to vendor_test/uses_numpy.py diff --git a/tests/vendor_test/vendored/__init__.py b/vendor_test/vendored/__init__.py similarity index 100% rename from tests/vendor_test/vendored/__init__.py rename to vendor_test/vendored/__init__.py diff --git a/vendor_test/vendored/_compat b/vendor_test/vendored/_compat new file mode 120000 index 00000000..ba484f91 --- /dev/null +++ b/vendor_test/vendored/_compat @@ -0,0 +1 @@ +../../array_api_compat/ \ No newline at end of file