diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index ec8bd6e1..95885080 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -14,43 +14,6 @@ from ._helpers import _check_device, _is_numpy_array, get_namespace -# Basic renames -def acos(x, /, xp): - return xp.arccos(x) - -def acosh(x, /, xp): - return xp.arccosh(x) - -def asin(x, /, xp): - return xp.arcsin(x) - -def asinh(x, /, xp): - return xp.arcsinh(x) - -def atan(x, /, xp): - return xp.arctan(x) - -def atan2(x1, x2, /, xp): - return xp.arctan2(x1, x2) - -def atanh(x, /, xp): - return xp.arctanh(x) - -def bitwise_left_shift(x1, x2, /, xp): - return xp.left_shift(x1, x2) - -def bitwise_invert(x, /, xp): - return xp.invert(x) - -def bitwise_right_shift(x1, x2, /, xp): - return xp.right_shift(x1, x2) - -def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray: - return xp.concatenate(arrays, axis=axis) - -def pow(x1, x2, /, xp): - return xp.power(x1, x2) - # These functions are modified from the NumPy versions. def arange( @@ -62,9 +25,10 @@ def arange( xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs ) -> ndarray: _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype) + return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], @@ -72,15 +36,17 @@ def empty( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs ) -> ndarray: _check_device(xp, device) - return xp.empty(shape, dtype=dtype) + return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs ) -> ndarray: _check_device(xp, device) - return xp.empty_like(x, dtype=dtype) + return xp.empty_like(x, dtype=dtype, **kwargs) def eye( n_rows: int, @@ -91,9 +57,10 @@ def eye( k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], @@ -102,9 +69,10 @@ def full( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype) + return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( x: ndarray, @@ -114,9 +82,10 @@ def full_like( xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype) + return xp.full_like(x, fill_value, dtype=dtype, **kwargs) def linspace( start: Union[int, float], @@ -128,9 +97,10 @@ def linspace( dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], @@ -138,15 +108,17 @@ def ones( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.ones(shape, dtype=dtype) + return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.ones_like(x, dtype=dtype) + return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], @@ -154,15 +126,17 @@ def zeros( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.zeros(shape, dtype=dtype) + return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + **kwargs, ) -> ndarray: _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype) + return xp.zeros_like(x, dtype=dtype, **kwargs) # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done @@ -256,8 +230,9 @@ def std( axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, + **kwargs, ) -> ndarray: - return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims) + return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( x: ndarray, @@ -267,8 +242,9 @@ def var( axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, + **kwargs, ) -> ndarray: - return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # Unlike transpose(), the axes argument to permute_dims() is required. def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: @@ -292,6 +268,7 @@ def _asarray( device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, namespace = None, + **kwargs, ) -> ndarray: """ Array API compatibility wrapper for asarray(). @@ -333,26 +310,31 @@ def _asarray( return xp.array(obj, copy=True, dtype=dtype) return obj - return xp.asarray(obj, dtype=dtype) + return xp.asarray(obj, dtype=dtype, **kwargs) # xp.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None) -> ndarray: +def reshape(x: ndarray, + /, + shape: Tuple[int, ...], + xp, copy: Optional[bool] = None, + **kwargs) -> ndarray: if copy is True: x = x.copy() elif copy is False: x.shape = shape return x - return xp.reshape(x, shape) + return xp.reshape(x, shape, **kwargs) # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" if not descending: - res = xp.argsort(x, axis=axis, kind=kind) + res = xp.argsort(x, axis=axis, kind=kind, **kwargs) else: # As NumPy has no native descending sort, we imitate it here. Note that # simply flipping the results of xp.argsort(x, ...) would not @@ -360,6 +342,7 @@ def argsort( res = xp.flip( xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind), axis=axis, + **kwargs, ) # Rely on flip()/argsort() to validate axis normalised_axis = axis if axis >= 0 else x.ndim + axis @@ -368,11 +351,12 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True + x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = xp.sort(x, axis=axis, kind=kind) + res = xp.sort(x, axis=axis, kind=kind, **kwargs) if descending: res = xp.flip(res, axis=axis) return res @@ -386,11 +370,12 @@ def sum( axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, + **kwargs, ) -> ndarray: # `xp.sum` already upcasts integers, but not floats if dtype is None and x.dtype == xp.float32: dtype = xp.float64 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) def prod( x: ndarray, @@ -400,32 +385,30 @@ def prod( axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[Dtype] = None, keepdims: bool = False, + **kwargs, ) -> ndarray: if dtype is None and x.dtype == xp.float32: dtype = xp.float64 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims) + return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp) -> ndarray: +def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x - return xp.ceil(x) + return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp) -> ndarray: +def floor(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x - return xp.floor(x) + return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp) -> ndarray: +def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: if xp.issubdtype(x.dtype, xp.integer): return x - return xp.trunc(x) - -__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', - 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult', - 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'astype', 'std', 'var', - 'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod', - 'ceil', 'floor', 'trunc'] + return xp.trunc(x, **kwargs) + +__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', + 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index f3bec324..c42879d6 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -10,17 +10,24 @@ from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: - return xp.cross(x1, x2, axis=axis) +def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: + return xp.cross(x1, x2, axis=axis, **kwargs) -def matmul(x1: ndarray, x2: ndarray, /, xp) -> ndarray: - return xp.matmul(x1, x2) +def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + return xp.matmul(x1, x2, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp) -> ndarray: - return xp.outer(x1, x2) +def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + return xp.outer(x1, x2, **kwargs) -def tensordot(x1: ndarray, x2: ndarray, /, xp, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> ndarray: - return xp.tensordot(x1, x2, axes=axes) +def tensordot(x1: ndarray, + x2: ndarray, + /, + xp, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> ndarray: + return xp.tensordot(x1, x2, axes=axes, **kwargs) class EighResult(NamedTuple): eigenvalues: ndarray @@ -41,35 +48,41 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: ndarray, /, xp) -> EighResult: - return EighResult(*xp.linalg.eigh(x)) +def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: + return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: - return QRResult(*xp.linalg.qr(x, mode=mode)) +def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', + **kwargs) -> QRResult: + return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp) -> SlogdetResult: - return SlogdetResult(*xp.linalg.slogdet(x)) +def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: + return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult: - return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices)) +def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: + return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray: - L = xp.linalg.cholesky(x) +def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: + L = xp.linalg.cholesky(x, **kwargs) if upper: return get_xp(xp)(matrix_transpose)(L) return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: +def matrix_rank(x: ndarray, + /, + xp, + *, + rtol: Optional[Union[float, ndarray]] = None, + **kwargs) -> ndarray: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = xp.linalg.svd(x, compute_uv=False) + S = xp.linalg.svd(x, compute_uv=False, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -78,12 +91,12 @@ def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = No tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray: +def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps - return xp.linalg.pinv(x, rcond=rtol) + return xp.linalg.pinv(x, rcond=rtol, **kwargs) # These functions are new in the array API spec @@ -152,11 +165,11 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: - return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1) +def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: + return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1)) +def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: + return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs)) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index cbc89381..939656ae 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -13,18 +13,20 @@ import cupy as cp bool = cp.bool_ -acos = get_xp(cp)(_aliases.acos) -acosh = get_xp(cp)(_aliases.acosh) -asin = get_xp(cp)(_aliases.asin) -asinh = get_xp(cp)(_aliases.asinh) -atan = get_xp(cp)(_aliases.atan) -atan2 = get_xp(cp)(_aliases.atan2) -atanh = get_xp(cp)(_aliases.atanh) -bitwise_left_shift = get_xp(cp)(_aliases.bitwise_left_shift) -bitwise_invert = get_xp(cp)(_aliases.bitwise_invert) -bitwise_right_shift = get_xp(cp)(_aliases.bitwise_right_shift) -concat = get_xp(cp)(_aliases.concat) -pow = get_xp(cp)(_aliases.pow) +# Basic renames +acos = cp.arccos +acosh = cp.arccosh +asin = cp.arcsin +asinh = cp.arcsinh +atan = cp.arctan +atan2 = cp.arctan2 +atanh = cp.arctanh +bitwise_left_shift = cp.left_shift +bitwise_invert = cp.invert +bitwise_right_shift = cp.right_shift +concat = cp.concatenate +pow = cp.power + arange = get_xp(cp)(_aliases.arange) empty = get_xp(cp)(_aliases.empty) empty_like = get_xp(cp)(_aliases.empty_like) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e6ff2ee2..31d28eeb 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -13,18 +13,20 @@ import numpy as np bool = np.bool_ -acos = get_xp(np)(_aliases.acos) -acosh = get_xp(np)(_aliases.acosh) -asin = get_xp(np)(_aliases.asin) -asinh = get_xp(np)(_aliases.asinh) -atan = get_xp(np)(_aliases.atan) -atan2 = get_xp(np)(_aliases.atan2) -atanh = get_xp(np)(_aliases.atanh) -bitwise_left_shift = get_xp(np)(_aliases.bitwise_left_shift) -bitwise_invert = get_xp(np)(_aliases.bitwise_invert) -bitwise_right_shift = get_xp(np)(_aliases.bitwise_right_shift) -concat = get_xp(np)(_aliases.concat) -pow = get_xp(np)(_aliases.pow) +# Basic renames +acos = np.arccos +acosh = np.arccosh +asin = np.arcsin +asinh = np.arcsinh +atan = np.arctan +atan2 = np.arctan2 +atanh = np.arctanh +bitwise_left_shift = np.left_shift +bitwise_invert = np.invert +bitwise_right_shift = np.right_shift +concat = np.concatenate +pow = np.power + arange = get_xp(np)(_aliases.arange) empty = get_xp(np)(_aliases.empty) empty_like = get_xp(np)(_aliases.empty_like)