Skip to content

Pass keyword arguments through in the aliases #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 59 additions & 76 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,6 @@

from ._helpers import _check_device, _is_numpy_array, get_namespace

# Basic renames
def acos(x, /, xp):
return xp.arccos(x)

def acosh(x, /, xp):
return xp.arccosh(x)

def asin(x, /, xp):
return xp.arcsin(x)

def asinh(x, /, xp):
return xp.arcsinh(x)

def atan(x, /, xp):
return xp.arctan(x)

def atan2(x1, x2, /, xp):
return xp.arctan2(x1, x2)

def atanh(x, /, xp):
return xp.arctanh(x)

def bitwise_left_shift(x1, x2, /, xp):
return xp.left_shift(x1, x2)

def bitwise_invert(x, /, xp):
return xp.invert(x)

def bitwise_right_shift(x1, x2, /, xp):
return xp.right_shift(x1, x2)

def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray:
return xp.concatenate(arrays, axis=axis)

def pow(x1, x2, /, xp):
return xp.power(x1, x2)

# These functions are modified from the NumPy versions.

def arange(
Expand All @@ -62,25 +25,28 @@ def arange(
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)

def empty(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype)
return xp.empty(shape, dtype=dtype, **kwargs)

def empty_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype)
return xp.empty_like(x, dtype=dtype, **kwargs)

def eye(
n_rows: int,
Expand All @@ -91,9 +57,10 @@ def eye(
k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)

def full(
shape: Union[int, Tuple[int, ...]],
Expand All @@ -102,9 +69,10 @@ def full(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)

def full_like(
x: ndarray,
Expand All @@ -114,9 +82,10 @@ def full_like(
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)

def linspace(
start: Union[int, float],
Expand All @@ -128,41 +97,46 @@ def linspace(
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
endpoint: bool = True,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)

def ones(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype)
return xp.ones(shape, dtype=dtype, **kwargs)

def ones_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype)
return xp.ones_like(x, dtype=dtype, **kwargs)

def zeros(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype)
return xp.zeros(shape, dtype=dtype, **kwargs)

def zeros_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype)
return xp.zeros_like(x, dtype=dtype, **kwargs)

# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
Expand Down Expand Up @@ -256,8 +230,9 @@ def std(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs,
) -> ndarray:
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims)
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

def var(
x: ndarray,
Expand All @@ -267,8 +242,9 @@ def var(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs,
) -> ndarray:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims)
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
Expand All @@ -292,6 +268,7 @@ def _asarray(
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
namespace = None,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().
Expand Down Expand Up @@ -333,33 +310,39 @@ def _asarray(
return xp.array(obj, copy=True, dtype=dtype)
return obj

return xp.asarray(obj, dtype=dtype)
return xp.asarray(obj, dtype=dtype, **kwargs)

# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None) -> ndarray:
def reshape(x: ndarray,
/,
shape: Tuple[int, ...],
xp, copy: Optional[bool] = None,
**kwargs) -> ndarray:
if copy is True:
x = x.copy()
elif copy is False:
x.shape = shape
return x
return xp.reshape(x, shape)
return xp.reshape(x, shape, **kwargs)

# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
**kwargs,
) -> ndarray:
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
if not descending:
res = xp.argsort(x, axis=axis, kind=kind)
res = xp.argsort(x, axis=axis, kind=kind, **kwargs)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of xp.argsort(x, ...) would not
# respect the relative order like it would in native descending sorts.
res = xp.flip(
xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind),
axis=axis,
**kwargs,
)
# Rely on flip()/argsort() to validate axis
normalised_axis = axis if axis >= 0 else x.ndim + axis
Expand All @@ -368,11 +351,12 @@ def argsort(
return res

def sort(
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
**kwargs,
) -> ndarray:
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
res = xp.sort(x, axis=axis, kind=kind)
res = xp.sort(x, axis=axis, kind=kind, **kwargs)
if descending:
res = xp.flip(res, axis=axis)
return res
Expand All @@ -386,11 +370,12 @@ def sum(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
# `xp.sum` already upcasts integers, but not floats
if dtype is None and x.dtype == xp.float32:
dtype = xp.float64
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

def prod(
x: ndarray,
Expand All @@ -400,32 +385,30 @@ def prod(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
if dtype is None and x.dtype == xp.float32:
dtype = xp.float64
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims)
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

# ceil, floor, and trunc return integers for integer inputs

def ceil(x: ndarray, /, xp) -> ndarray:
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.ceil(x)
return xp.ceil(x, **kwargs)

def floor(x: ndarray, /, xp) -> ndarray:
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.floor(x)
return xp.floor(x, **kwargs)

def trunc(x: ndarray, /, xp) -> ndarray:
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.trunc(x)

__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
'UniqueInverseResult', 'unique_all', 'unique_counts',
'unique_inverse', 'unique_values', 'astype', 'std', 'var',
'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod',
'ceil', 'floor', 'trunc']
return xp.trunc(x, **kwargs)

__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']
Loading